Skip to content

Commit

Permalink
plonk: multiprover: constraint-system: Implement MpcCircuit
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Oct 21, 2023
1 parent be21dbe commit 46795b4
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 21 deletions.
1 change: 1 addition & 0 deletions plonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ark-mpc = { git = "https://github.com/renegade-fi/ark-mpc.git" }
ark-poly = "0.4.2"
ark-serialize = "0.4.0"
ark-std = { version = "0.4.0", default-features = false }
async-trait = "0.1"
derivative = { version = "2", features = ["use_core"] }
displaydoc = { version = "0.2.3", default-features = false }
downcast-rs = { version = "1.2.0", default-features = false }
Expand Down
153 changes: 134 additions & 19 deletions plonk/src/multiprover/proof_system/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ use ark_mpc::{
MpcFabric,
};
use ark_poly::{univariate::DensePolynomial, EvaluationDomain, Radix2EvaluationDomain};
use async_trait::async_trait;
use futures::future;
use itertools::Itertools;
use jf_relation::{
constants::{GATE_WIDTH, N_MUL_SELECTORS},
errors::CircuitError,
gates::{Gate, IoGate, PaddingGate},
gates::{
AdditionGate, BoolGate, ConstantGate, EqualityGate, Gate, IoGate, MultiplicationGate,
PaddingGate, SubtractionGate,
},
GateId, Variable, WireId,
};

Expand Down Expand Up @@ -50,6 +55,7 @@ impl MpcBoolVar {
///
/// This is largely a re-implementation of the existing `Circuit` trait, made to
/// work over a secret shared field
#[async_trait]
pub trait MpcCircuit<C: CurveGroup> {
/// The number of constraints
fn num_gates(&self) -> usize;
Expand All @@ -72,7 +78,7 @@ pub trait MpcCircuit<C: CurveGroup> {
fn public_input(&self) -> Result<Vec<AuthenticatedScalarResult<C>>, MpcCircuitError>;

/// Check whether the circuit constraints are satisfied
fn check_circuit_satisfiability(
async fn check_circuit_satisfiability(
&self,
public_input: &[AuthenticatedScalarResult<C>],
) -> Result<(), MpcCircuitError>;
Expand Down Expand Up @@ -674,6 +680,7 @@ impl<C: CurveGroup> MpcPlonkCircuit<C> {
}
}

#[async_trait]
impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
fn num_gates(&self) -> usize {
self.gates.len()
Expand Down Expand Up @@ -703,11 +710,50 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
// Note: This method involves opening the witness values, it should only be
// used in testing contexts
#[cfg(feature = "test_apis")]
fn check_circuit_satisfiability(
async fn check_circuit_satisfiability(
&self,
public_input: &[AuthenticatedScalarResult<C>],
) -> Result<(), MpcCircuitError> {
unimplemented!()
let n = public_input.len();
if n != self.num_inputs() {
return Err(MpcCircuitError::ConstraintSystem(
CircuitError::PubInputLenMismatch(n, self.pub_input_gate_ids.len()),
));
}

let mut gate_results = Vec::new();

// Check public I/O gates
for (i, gate_id) in self.pub_input_gate_ids.iter().enumerate() {
let pi = &public_input[i];
gate_results.push(self.check_gate(*gate_id, pi));
}

// Check rest of the gates
let zero = self.fabric.zero_authenticated();
for gate_id in 0..self.num_gates() {
if !self.is_io_gate(gate_id) {
let res = self.check_gate(gate_id, &zero /* public_input */);
gate_results.push(res);
}
}

// Await all the gate results
future::join_all(gate_results)
.await
.into_iter()
.enumerate()
.map(|(idx, res)| {
if res == Scalar::zero() {
Ok(())
} else {
Err(MpcCircuitError::ConstraintSystem(
CircuitError::GateCheckFailure(idx, "gate check failed".to_string()),
))
}
})
.collect::<Result<Vec<_>, MpcCircuitError>>()
.map(|_| ())
}

#[cfg(not(feature = "test_apis"))]
Expand All @@ -719,7 +765,7 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
}

fn create_constant_variable(&mut self, val: Scalar<C>) -> Result<MpcVariable, MpcCircuitError> {
let authenticated_val = self.fabric.one_authenticated() * &val;
let authenticated_val = self.fabric.one_authenticated() * val;
let var = self.create_variable(authenticated_val)?;
self.enforce_constant(var, val)?;

Expand All @@ -730,18 +776,31 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
&mut self,
val: AuthenticatedScalarResult<C>,
) -> Result<MpcVariable, MpcCircuitError> {
unimplemented!()
self.check_finalize_flag(false)?;
self.witness.push(val);
self.num_vars += 1;

Ok(self.num_vars - 1)
}

fn create_public_variable(
&mut self,
val: AuthenticatedScalarResult<C>,
) -> Result<MpcVariable, MpcCircuitError> {
unimplemented!()
let var = self.create_variable(val)?;
self.set_variable_public(var)?;

Ok(var)
}

fn set_variable_public(&mut self, var: MpcVariable) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_finalize_flag(false)?;
self.pub_input_gate_ids.push(self.num_gates());

// Create an io gate that forces `witness[var] = public_input`.
let wire_vars = &[0, 0, 0, 0, var];
self.insert_gate(wire_vars, Box::new(IoGate))?;
Ok(())
}

fn zero(&self) -> MpcVariable {
Expand All @@ -753,15 +812,21 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
}

fn witness(&self, idx: MpcVariable) -> Result<AuthenticatedScalarResult<C>, MpcCircuitError> {
unimplemented!()
self.check_var_bound(idx)?;

Ok(self.witness[idx].clone())
}

fn enforce_constant(
&mut self,
var: MpcVariable,
constant: Scalar<C>,
) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(var)?;

let wire_vars = &[0, 0, 0, 0, var];
self.insert_gate(wire_vars, Box::new(ConstantGate(constant.inner())))?;
Ok(())
}

fn add_gate(
Expand All @@ -770,11 +835,24 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
b: MpcVariable,
c: MpcVariable,
) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;
self.check_var_bound(c)?;

let wire_vars = &[a, b, 0, 0, c];
self.insert_gate(wire_vars, Box::new(AdditionGate))?;
Ok(())
}

fn add(&mut self, a: MpcVariable, b: MpcVariable) -> Result<MpcVariable, MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;

let res = self.witness(a)? + self.witness(b)?;
let c = self.create_variable(res)?;

self.add_gate(a, b, c)?;
Ok(c)
}

fn sub_gate(
Expand All @@ -783,11 +861,24 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
b: MpcVariable,
c: MpcVariable,
) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;
self.check_var_bound(c)?;

let wire_vars = &[a, b, 0, 0, c];
self.insert_gate(wire_vars, Box::new(SubtractionGate))?;
Ok(())
}

fn sub(&mut self, a: MpcVariable, b: MpcVariable) -> Result<MpcVariable, MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;

let res = self.witness(a)? - self.witness(b)?;
let c = self.create_variable(res)?;

self.sub_gate(a, b, c)?;
Ok(c)
}

fn mul_gate(
Expand All @@ -796,22 +887,46 @@ impl<C: CurveGroup> MpcCircuit<C> for MpcPlonkCircuit<C> {
b: MpcVariable,
c: MpcVariable,
) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;

let wire_vars = &[a, b, 0, 0, c];
self.insert_gate(wire_vars, Box::new(MultiplicationGate))?;
Ok(())
}

fn mul(&mut self, a: MpcVariable, b: MpcVariable) -> Result<MpcVariable, MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;

let res = self.witness(a)? * self.witness(b)?;
let c = self.create_variable(res)?;

self.mul_gate(a, b, c)?;
Ok(c)
}

fn enforce_bool(&mut self, a: MpcVariable) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;

let wire_vars = &[a, a, 0, 0, a];
self.insert_gate(wire_vars, Box::new(BoolGate))?;
Ok(())
}

fn enforce_equal(&mut self, a: MpcVariable, b: MpcVariable) -> Result<(), MpcCircuitError> {
unimplemented!()
self.check_var_bound(a)?;
self.check_var_bound(b)?;

let wire_vars = &[a, b, 0, 0, 0];
self.insert_gate(wire_vars, Box::new(EqualityGate))?;
Ok(())
}

fn pad_gates(&mut self, n: usize) {
unimplemented!()
let wire_vars = &[self.zero(), self.zero(), 0, 0, 0];
for _ in 0..n {
self.insert_gate(wire_vars, Box::new(EqualityGate)).unwrap();
}
}
}
2 changes: 1 addition & 1 deletion relation/src/gates/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use ark_ff::Field;

/// A constant gate
#[derive(Debug, Clone)]
pub struct ConstantGate<F: Field>(pub(crate) F);
pub struct ConstantGate<F: Field>(pub F);

impl<F> Gate<F> for ConstantGate<F>
where
Expand Down
2 changes: 1 addition & 1 deletion relation/src/gates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use logic::*;
pub use lookup::*;

/// Describes a gate with getter for all selectors configuration
pub trait Gate<F: Field>: Downcast + DynClone {
pub trait Gate<F: Field>: Downcast + DynClone + Send + Sync {
/// Get the name of a gate.
fn name(&self) -> &'static str;
/// Selectors for linear combination.
Expand Down

0 comments on commit 46795b4

Please sign in to comment.