From 470bbfa352c87868cbef44cbc03e05053fc2d793 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Mon, 9 Oct 2023 18:11:15 -0700 Subject: [PATCH] network, fabric, algebra: Make entire crate generic over curve choice --- Cargo.toml | 3 +- src/algebra/authenticated_curve.rs | 22 ++- src/algebra/authenticated_scalar.rs | 34 ++-- src/algebra/curve.rs | 98 ++++------- src/algebra/macros.rs | 26 +-- src/algebra/mod.rs | 13 +- src/algebra/mpc_curve.rs | 10 +- src/algebra/mpc_scalar.rs | 68 ++++---- src/algebra/scalar.rs | 72 ++++++-- src/beaver.rs | 33 ++-- src/commitment.rs | 53 +++--- src/fabric.rs | 261 +++++++++++++++------------- src/fabric/executor.rs | 37 ++-- src/fabric/network_sender.rs | 21 +-- src/fabric/result.rs | 77 ++++---- src/lib.rs | 28 +-- src/network.rs | 51 +++--- src/network/mock.rs | 44 ++--- src/network/quic.rs | 39 +++-- 19 files changed, 519 insertions(+), 471 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9909ede..8431b33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ path = "src/lib.rs" [features] benchmarks = [] stats = ["benchmarks"] -test_helpers = [] +test_helpers = ["ark-curve25519"] [[test]] name = "integration" @@ -73,6 +73,7 @@ kanal = "0.1.0-pre8" tokio = { version = "1.12", features = ["macros", "rt-multi-thread"] } # == Arithemtic + Crypto == # +ark-curve25519 = { version = "0.4", optional = true } ark-ec = { version = "0.4", features = ["parallel"] } ark-ff = "0.4" ark-serialize = "0.4" diff --git a/src/algebra/authenticated_curve.rs b/src/algebra/authenticated_curve.rs index b8f03d1..173a234 100644 --- a/src/algebra/authenticated_curve.rs +++ b/src/algebra/authenticated_curve.rs @@ -116,12 +116,13 @@ impl AuthenticatedPointResult { n: usize, ) -> Vec> { // Convert to a set of scalar results - let scalar_results = values - .fabric() - .new_batch_gate_op(vec![values.id()], n, |mut args| { - let args: Vec> = args.pop().unwrap().into(); - args.into_iter().map(ResultValue::Point).collect_vec() - }); + let scalar_results: Vec> = + values + .fabric() + .new_batch_gate_op(vec![values.id()], n, |mut args| { + let args: Vec> = args.pop().unwrap().into(); + args.into_iter().map(ResultValue::Point).collect_vec() + }); Self::new_shared_batch(&scalar_results) } @@ -137,7 +138,7 @@ impl AuthenticatedPointResult { } /// Borrow the fabric that this result is allocated in - pub fn fabric(&self) -> &MpcFabric { + pub fn fabric(&self) -> &MpcFabric { self.share.fabric() } @@ -400,7 +401,10 @@ impl Debug for AuthenticatedPointOpenResult { } } -impl Future for AuthenticatedPointOpenResult { +impl Future for AuthenticatedPointOpenResult +where + C::ScalarField: Unpin, +{ type Output = Result, MpcError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -408,7 +412,7 @@ impl Future for AuthenticatedPointOpenResult { let value = futures::ready!(self.as_mut().value.poll_unpin(cx)); let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx)); - if mac_check == Scalar::from(1) { + if mac_check == Scalar::from(1u8) { Poll::Ready(Ok(value)) } else { Poll::Ready(Err(MpcError::AuthenticationError)) diff --git a/src/algebra/authenticated_scalar.rs b/src/algebra/authenticated_scalar.rs index b3629ce..49f5b82 100644 --- a/src/algebra/authenticated_scalar.rs +++ b/src/algebra/authenticated_scalar.rs @@ -16,7 +16,7 @@ use crate::{ commitment::{PedersenCommitment, PedersenCommitmentResult}, error::MpcError, fabric::{MpcFabric, ResultId, ResultValue}, - ResultHandle, PARTY0, + PARTY0, }; use super::{ @@ -34,7 +34,7 @@ pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3; /// SPDZ protocol: https://eprint.iacr.org/2011/535.pdf /// that ensures security against a malicious adversary #[derive(Clone)] -pub struct AuthenticatedScalarResult { +pub struct AuthenticatedScalarResult { /// The secret shares of the underlying value pub(crate) share: MpcScalarResult, /// The SPDZ style, unconditionally secure MAC of the value @@ -119,12 +119,13 @@ impl AuthenticatedScalarResult { n: usize, ) -> Vec> { // Convert to a set of scalar results - let scalar_results = values - .fabric() - .new_batch_gate_op(vec![values.id()], n, |mut args| { - let scalars: Vec> = args.pop().unwrap().into(); - scalars.into_iter().map(ResultValue::Scalar).collect() - }); + let scalar_results: Vec> = + values + .fabric() + .new_batch_gate_op(vec![values.id()], n, |mut args| { + let scalars: Vec> = args.pop().unwrap().into(); + scalars.into_iter().map(ResultValue::Scalar).collect() + }); Self::new_shared_batch(&scalar_results) } @@ -141,7 +142,7 @@ impl AuthenticatedScalarResult { } /// Get a reference to the underlying MPC fabric - pub fn fabric(&self) -> &MpcFabric { + pub fn fabric(&self) -> &MpcFabric { self.share.fabric() } @@ -198,7 +199,7 @@ impl AuthenticatedScalarResult { } // Sum of the commitments should be zero - if peer_mac_share + my_mac_share != Scalar::from(0) { + if peer_mac_share + my_mac_share != Scalar::zero() { return false; } @@ -396,7 +397,10 @@ pub struct AuthenticatedScalarOpenResult { pub mac_check: ScalarResult, } -impl Future for AuthenticatedScalarOpenResult { +impl Future for AuthenticatedScalarOpenResult +where + C::ScalarField: Unpin, +{ type Output = Result, MpcError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -404,7 +408,7 @@ impl Future for AuthenticatedScalarOpenResult { let value = futures::ready!(self.as_mut().value.poll_unpin(cx)); let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx)); - if mac_check == Scalar::from(1) { + if mac_check == Scalar::from(1u8) { Poll::Ready(Ok(value)) } else { Poll::Ready(Err(MpcError::AuthenticationError)) @@ -425,7 +429,7 @@ impl Add<&Scalar> for &AuthenticatedScalarResult { let new_share = if self.fabric().party_id() == PARTY0 { &self.share + rhs } else { - &self.share + Scalar::from(0) + &self.share + Scalar::zero() }; // Both parties add the public value to their modifier, and the MACs do not change @@ -452,7 +456,7 @@ impl Add<&ScalarResult> for &AuthenticatedScalarResult { let new_share = if self.fabric().party_id() == PARTY0 { &self.share + rhs } else { - &self.share + Scalar::from(0) + &self.share + Scalar::zero() }; let new_modifier = &self.public_modifier - rhs; @@ -1144,6 +1148,6 @@ mod tests { }) .await; - assert_eq!(res.unwrap(), 0.into()); + assert_eq!(res.unwrap(), 0u8.into()); } } diff --git a/src/algebra/curve.rs b/src/algebra/curve.rs index 9f86ae7..aa9a566 100644 --- a/src/algebra/curve.rs +++ b/src/algebra/curve.rs @@ -12,14 +12,13 @@ use ark_ec::{ map_to_curve_hasher::MapToCurve, HashToCurveError, }, - short_weierstrass::{Affine, Projective, SWCurveConfig}, - CurveConfig, CurveGroup, Group, VariableBaseMSM, + short_weierstrass::Projective, + CurveGroup, }; -use ark_ff::{BigInt, MontFp, PrimeField, Zero}; +use ark_ff::PrimeField; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError}; +use ark_serialize::SerializationError; use itertools::Itertools; -use num_bigint::BigUint; use serde::{de::Error as DeError, Deserialize, Serialize}; use crate::{ @@ -54,6 +53,7 @@ pub const HASH_TO_CURVE_SECURITY: usize = 16; // 128 bit security /// A wrapper around the inner point that allows us to define foreign traits on the point #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct CurvePoint(pub(crate) C); +impl Unpin for CurvePoint {} impl Serialize for CurvePoint { fn serialize(&self, serializer: S) -> Result { @@ -96,22 +96,6 @@ impl CurvePoint { self.0.into_affine() } - /// Construct a `CurvePoint` from its affine coordinates - pub fn from_affine_coords(x: BigUint, y: BigUint) -> Self { - let x_bigint = BigInt::try_from(x).unwrap(); - let y_bigint = BigInt::try_from(y).unwrap(); - let x = Self::BaseField::from(x_bigint); - let y = Self::BaseField::from(y_bigint); - - let aff = Affine { - x, - y, - infinity: false, - }; - - Self(aff.into()) - } - /// The group generator pub fn generator() -> CurvePoint { CurvePoint(C::generator()) @@ -132,13 +116,24 @@ impl CurvePoint { let point = C::deserialize_compressed(bytes)?; Ok(CurvePoint(point)) } +} +impl CurvePoint +where + C::BaseField: PrimeField, +{ /// Get the number of bytes needed to represent a point, this is exactly the number of bytes /// for one base field element, as we can simply use the x-coordinate and set a high bit for the `y` pub fn n_bytes() -> usize { - n_bytes_field::() + n_bytes_field::() } +} +impl CurvePoint +where + C::Config: SWUConfig, + C::BaseField: PrimeField, +{ /// Convert a uniform byte buffer to a `CurvePoint` via the SWU map-to-curve approach: /// /// See https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-09#simple-swu @@ -146,7 +141,9 @@ impl CurvePoint { /// result of an `extend_message` implementation that gives us its uniform digest. From here /// we construct two field elements, map to curve, and add the points to give a uniformly /// distributed curve point - pub fn from_uniform_bytes(buf: Vec) -> Result, HashToCurveError> { + pub fn from_uniform_bytes( + buf: Vec, + ) -> Result>, HashToCurveError> { let n_bytes = Self::n_bytes(); assert_eq!( buf.len(), @@ -169,7 +166,7 @@ impl CurvePoint { } /// A helper that converts an arbitrarily long byte buffer to a field element - fn hash_to_field(buf: &[u8]) -> Self::BaseField { + fn hash_to_field(buf: &[u8]) -> C::BaseField { Self::BaseField::from_be_bytes_mod_order(buf) } } @@ -194,7 +191,6 @@ impl Add<&C> for &CurvePoint { } } impl_borrow_variants!(CurvePoint, Add, add, +, C, C: CurveGroup); -impl_commutative!(CurvePoint, Add, add, +, C, C: CurveGroup); impl Add<&CurvePoint> for &CurvePoint { type Output = CurvePoint; @@ -432,7 +428,7 @@ impl Mul<&ScalarResult> for &CurvePoint { fn mul(self, rhs: &ScalarResult) -> Self::Output { let self_owned = *self; rhs.fabric.new_gate_op(vec![rhs.id], move |args| { - let rhs: Scalar = args[0].to_owned().into(); + let rhs: Scalar = args[0].to_owned().into(); ResultValue::Point(CurvePoint(self_owned.0 * rhs.0)) }) } @@ -446,7 +442,7 @@ impl Mul<&ScalarResult> for &CurvePointResult { fn mul(self, rhs: &ScalarResult) -> Self::Output { self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| { let lhs: CurvePoint = args.remove(0).into(); - let rhs: Scalar = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); ResultValue::Point(CurvePoint(lhs.0 * rhs.0)) }) @@ -620,7 +616,7 @@ impl CurvePoint { /// represented as streaming iterators pub fn msm_iter(scalars: I, points: J) -> CurvePoint where - I: IntoIterator, + I: IntoIterator>, J: IntoIterator>, { let mut res = CurvePoint::identity(); @@ -630,7 +626,7 @@ impl CurvePoint { .into_iter() .zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter()) { - let scalars: Vec = scalar_chunk.collect(); + let scalars: Vec> = scalar_chunk.collect(); let points: Vec> = point_chunk.collect(); let chunk_res = CurvePoint::msm(&scalars, &points); @@ -667,7 +663,7 @@ impl CurvePoint { /// as iterators. Assumes the iterators are non-empty pub fn msm_results_iter(scalars: I, points: J) -> CurvePointResult where - I: IntoIterator, + I: IntoIterator>, J: IntoIterator>, { Self::msm_results( @@ -731,10 +727,10 @@ impl CurvePoint { /// This method assumes that the iterators are of the same length pub fn msm_authenticated_iter(scalars: I, points: J) -> AuthenticatedPointResult where - I: IntoIterator, + I: IntoIterator>, J: IntoIterator>, { - let scalars: Vec = scalars.into_iter().collect(); + let scalars: Vec> = scalars.into_iter().collect(); let points: Vec> = points.into_iter().collect(); Self::msm_authenticated(&scalars, &points) @@ -777,7 +773,7 @@ impl CurvePointResult { /// Assumes the iterator is non-empty pub fn msm_results_iter(scalars: I, points: J) -> CurvePointResult where - I: IntoIterator, + I: IntoIterator>, J: IntoIterator>, { Self::msm_results( @@ -848,10 +844,10 @@ impl CurvePointResult { /// represented as streaming iterators pub fn msm_authenticated_iter(scalars: I, points: J) -> AuthenticatedPointResult where - I: IntoIterator, + I: IntoIterator>, J: IntoIterator>, { - let scalars: Vec = scalars.into_iter().collect(); + let scalars: Vec> = scalars.into_iter().collect(); let points: Vec> = points.into_iter().collect(); Self::msm_authenticated(&scalars, &points) @@ -866,11 +862,10 @@ impl CurvePointResult { /// https://github.com/xJonathanLEI/starknet-rs #[cfg(test)] mod test { - use rand::{thread_rng, RngCore}; - - use crate::algebra::test_helper::{random_point, TestCurve}; + use crate::algebra::test_helper::random_point; use super::*; + /// Test that the generators are the same between the two curve representations #[test] fn test_generators() { @@ -922,31 +917,4 @@ mod test { assert_eq!(p1, res); } - - /// Tests the hash-to-curve implementation `CurvePoint::from_uniform_bytes` - #[test] - fn test_hash_to_curve() { - // Sample random bytes into a buffer - let mut rng = thread_rng(); - const N_BYTES: usize = n_bytes_field::(); - let mut buf = [0u8; N_BYTES * 2]; - rng.fill_bytes(&mut buf); - - // As long as the method does not error, the test is successful - let res = CurvePoint::from_uniform_bytes(buf); - assert!(res.is_ok()) - } - - /// Tests converting to and from affine coordinates - #[test] - fn test_affine_conversion() { - let projective = random_point(); - let affine = projective.to_affine(); - - let x = BigUint::from(affine.x); - let y = BigUint::from(affine.y); - let recovered = CurvePoint::from_affine_coords(x, y); - - assert_eq!(projective, recovered); - } } diff --git a/src/algebra/macros.rs b/src/algebra/macros.rs index 59c46cb..f38e068 100644 --- a/src/algebra/macros.rs +++ b/src/algebra/macros.rs @@ -4,9 +4,9 @@ /// implements the same arithmetic on the owned and partially-owned variants macro_rules! impl_borrow_variants { // Single type trait - ($target:ty, $trait:ident, $fn_name:ident, $op:tt, $($gen:ident: $gen_ty:ty),*) => { + ($target:ty, $trait:ident, $fn_name:ident, $op:tt, $($gen:ident: $gen_ty:ident),*) => { // Single implementation, owned target type - impl<$($gen),*> $trait for $target { + impl<$($gen:$gen_ty),*> $trait for $target { type Output = $target; fn $fn_name(self) -> Self::Output { @@ -16,14 +16,14 @@ macro_rules! impl_borrow_variants { }; // Output type same as left hand side - ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, $($gen:ident: $gen_ty:ty),*) => { + ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, $($gen:ident: $gen_ty:ident),*) => { impl_borrow_variants!($lhs, $trait, $fn_name, $op, $rhs, Output=$lhs, $($gen: $gen_ty),*); }; // Output type specified - ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, Output=$out_type:ty, $($gen:ident: $gen_ty:ty),*) => { + ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, Output=$out_type:ty, $($gen:ident: $gen_ty:ident),*) => { /// lhs borrowed, rhs owned - impl<'a, $($gen),*> $trait<$rhs> for &'a $lhs { + impl<'a, $($gen: $gen_ty),*> $trait<$rhs> for &'a $lhs { type Output = $out_type; fn $fn_name(self, rhs: $rhs) -> Self::Output { @@ -32,7 +32,7 @@ macro_rules! impl_borrow_variants { } /// lhs owned, rhs borrowed - impl<'a, $($gen),*> $trait<&'a $rhs> for $lhs { + impl<'a, $($gen: $gen_ty),*> $trait<&'a $rhs> for $lhs { type Output = $out_type; fn $fn_name(self, rhs: &'a $rhs) -> Self::Output { @@ -41,7 +41,7 @@ macro_rules! impl_borrow_variants { } /// lhs owned, rhs owned - impl<$($gen),*> $trait<$rhs> for $lhs { + impl<$($gen: $gen_ty),*> $trait<$rhs> for $lhs { type Output = $out_type; fn $fn_name(self, rhs: $rhs) -> Self::Output { @@ -53,13 +53,13 @@ macro_rules! impl_borrow_variants { /// A macro to implement commutative variants of a binary operation macro_rules! impl_commutative { - ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, $($gen:ident: $gen_ty:ty),*) => { + ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, $($gen:ident: $gen_ty:ident),*) => { impl_commutative!($lhs, $trait, $fn_name, $op, $rhs, Output=$lhs, $($gen: $gen_ty),*); }; - ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, Output=$out_type:ty, $($gen:ident: $gen_ty:ty),*) => { + ($lhs:ty, $trait:ident, $fn_name:ident, $op:tt, $rhs:ty, Output=$out_type:ty, $($gen:ident: $gen_ty:ident),*) => { /// lhs borrowed, rhs borrowed - impl<'a, $($gen),*> $trait<&'a $lhs> for &'a $rhs { + impl<'a, $($gen: $gen_ty),*> $trait<&'a $lhs> for &'a $rhs { type Output = $out_type; fn $fn_name(self, rhs: &'a $lhs) -> Self::Output { @@ -68,7 +68,7 @@ macro_rules! impl_commutative { } /// lhs borrowed, rhs owned - impl<'a, $($gen),*> $trait<$lhs> for &'a $rhs + impl<'a, $($gen: $gen_ty),*> $trait<$lhs> for &'a $rhs { type Output = $out_type; @@ -78,7 +78,7 @@ macro_rules! impl_commutative { } /// lhs owned, rhs borrowed - impl<'a, $($gen),*> $trait<&'a $lhs> for $rhs + impl<'a, $($gen: $gen_ty),*> $trait<&'a $lhs> for $rhs { type Output = $out_type; @@ -88,7 +88,7 @@ macro_rules! impl_commutative { } /// lhs owned, rhs owned - impl<$($gen),*> $trait<$lhs> for $rhs + impl<$($gen: $gen_ty),*> $trait<$lhs> for $rhs { type Output = $out_type; diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 71ee380..9d41d3b 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -9,14 +9,11 @@ pub mod mpc_scalar; pub mod scalar; /// Helpers useful for testing throughout the `algebra` module -#[cfg(test)] +#[cfg(any(test, feature = "test_helpers"))] pub(crate) mod test_helper { - use std::iter; - - use super::scalar::Scalar; + use super::{curve::CurvePoint, scalar::Scalar}; use ark_curve25519::EdwardsProjective as Curve25519Projective; - use ark_ec::CurveGroup; use ark_ff::PrimeField; use num_bigint::BigUint; use rand::thread_rng; @@ -27,12 +24,14 @@ pub(crate) mod test_helper { /// A curve used for testing algebra implementations, set to curve25519 pub type TestCurve = Curve25519Projective; + /// A curve point on the test curve + pub type TestCurvePoint = CurvePoint; /// Generate a random point, by multiplying the basepoint with a random scalar - pub fn random_point() -> TestCurve { + pub fn random_point() -> TestCurvePoint { let mut rng = thread_rng(); let scalar = Scalar::random(&mut rng); - let point = TestCurve::generator() * scalar; + let point = TestCurvePoint::generator() * scalar; point * scalar } diff --git a/src/algebra/mpc_curve.rs b/src/algebra/mpc_curve.rs index 093d5d4..2f24b55 100644 --- a/src/algebra/mpc_curve.rs +++ b/src/algebra/mpc_curve.rs @@ -41,14 +41,14 @@ impl MpcPointResult { } /// Borrow the fabric that this result is allocated in - pub fn fabric(&self) -> &MpcFabric { + pub fn fabric(&self) -> &MpcFabric { self.share.fabric() } /// Open the value; both parties send their shares to the counterparty pub fn open(&self) -> CurvePointResult { let send_my_share = - |args: Vec| NetworkPayload::Point(args[0].to_owned().into()); + |args: Vec>| NetworkPayload::Point(args[0].to_owned().into()); // Party zero sends first then receives let (share0, share1): (CurvePointResult, CurvePointResult) = @@ -76,7 +76,7 @@ impl MpcPointResult { let n = values.len(); let fabric = &values[0].fabric(); let all_ids = values.iter().map(|v| v.id()).collect_vec(); - let send_my_shares = |args: Vec| { + let send_my_shares = |args: Vec>| { NetworkPayload::PointBatch(args.into_iter().map(|arg| arg.into()).collect_vec()) }; @@ -444,8 +444,8 @@ impl Mul<&ScalarResult> for &MpcPointResult { fn mul(self, rhs: &ScalarResult) -> Self::Output { self.fabric() .new_gate_op(vec![self.id(), rhs.id()], |mut args| { - let lhs: CurvePoint = args.remove(0).into(); - let rhs: Scalar = args.remove(0).into(); + let lhs: CurvePoint = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); ResultValue::Point(lhs * rhs) }) diff --git a/src/algebra/mpc_scalar.rs b/src/algebra/mpc_scalar.rs index a55477a..29aaa39 100644 --- a/src/algebra/mpc_scalar.rs +++ b/src/algebra/mpc_scalar.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use crate::{ algebra::scalar::BatchScalarResult, - fabric::{MpcFabric, ResultHandle, ResultValue}, + fabric::{MpcFabric, ResultValue}, network::NetworkPayload, PARTY0, }; @@ -46,7 +46,7 @@ impl MpcScalarResult { } /// Borrow the fabric that the result is allocated in - pub fn fabric(&self) -> &MpcFabric { + pub fn fabric(&self) -> &MpcFabric { self.share.fabric() } @@ -54,17 +54,17 @@ impl MpcScalarResult { pub fn open(&self) -> ScalarResult { // Party zero sends first then receives let (val0, val1) = if self.fabric().party_id() == PARTY0 { - let party0_value: ResultHandle = + let party0_value: ScalarResult = self.fabric().new_network_op(vec![self.id()], |args| { - let share: Scalar = args[0].to_owned().into(); + let share: Scalar = args[0].to_owned().into(); NetworkPayload::Scalar(share) }); - let party1_value: ResultHandle = self.fabric().receive_value(); + let party1_value: ScalarResult = self.fabric().receive_value(); (party0_value, party1_value) } else { - let party0_value: ResultHandle = self.fabric().receive_value(); - let party1_value: ResultHandle = + let party0_value: ScalarResult = self.fabric().receive_value(); + let party1_value: ScalarResult = self.fabric().new_network_op(vec![self.id()], |args| { let share = args[0].to_owned().into(); NetworkPayload::Scalar(share) @@ -86,8 +86,8 @@ impl MpcScalarResult { let n = values.len(); let fabric = &values[0].fabric(); let my_results = values.iter().map(|v| v.id()).collect_vec(); - let send_shares_fn = |args: Vec| { - let shares: Vec = args.into_iter().map(Scalar::from).collect(); + let send_shares_fn = |args: Vec>| { + let shares: Vec> = args.into_iter().map(Scalar::from).collect(); NetworkPayload::ScalarBatch(shares) }; @@ -109,8 +109,8 @@ impl MpcScalarResult { // Create the new values by combining the additive shares fabric.new_batch_gate_op(vec![party0_vals.id, party1_vals.id], n, move |args| { - let party0_vals: Vec = args[0].to_owned().into(); - let party1_vals: Vec = args[1].to_owned().into(); + let party0_vals: Vec> = args[0].to_owned().into(); + let party1_vals: Vec> = args[1].to_owned().into(); let mut results = Vec::with_capacity(n); for i in 0..n { @@ -144,7 +144,7 @@ impl Add<&Scalar> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id()], move |args| { // Cast the args - let lhs_share: Scalar = args[0].to_owned().into(); + let lhs_share: Scalar = args[0].to_owned().into(); if party_id == PARTY0 { ResultValue::Scalar(lhs_share + rhs) } else { @@ -166,8 +166,8 @@ impl Add<&ScalarResult> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id(), rhs.id], move |mut args| { // Cast the args - let lhs: Scalar = args.remove(0).into(); - let rhs: Scalar = args.remove(0).into(); + let lhs: Scalar = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); if party_id == PARTY0 { ResultValue::Scalar(lhs + rhs) @@ -188,8 +188,8 @@ impl Add<&MpcScalarResult> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id(), rhs.id()], |args| { // Cast the args - let lhs: Scalar = args[0].to_owned().into(); - let rhs: Scalar = args[1].to_owned().into(); + let lhs: Scalar = args[0].to_owned().into(); + let rhs: Scalar = args[1].to_owned().into(); ResultValue::Scalar(lhs + rhs) }) @@ -253,11 +253,11 @@ impl MpcScalarResult { let scalars: Vec> = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| { if party_id == PARTY0 { - let mut res: Vec = Vec::with_capacity(n); + let mut res: Vec> = Vec::with_capacity(n); for i in 0..n { - let lhs: Scalar = args[i].to_owned().into(); - let rhs: Scalar = args[i + n].to_owned().into(); + let lhs: Scalar = args[i].to_owned().into(); + let rhs: Scalar = args[i + n].to_owned().into(); res.push(ResultValue::Scalar(lhs + rhs)); } @@ -353,8 +353,8 @@ impl Sub<&MpcScalarResult> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id(), rhs.id()], |args| { // Cast the args - let lhs: Scalar = args[0].to_owned().into(); - let rhs: Scalar = args[1].to_owned().into(); + let lhs: Scalar = args[0].to_owned().into(); + let rhs: Scalar = args[1].to_owned().into(); ResultValue::Scalar(lhs - rhs) }) @@ -422,11 +422,11 @@ impl MpcScalarResult { let party_id = fabric.party_id(); let scalars = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| { if party_id == PARTY0 { - let mut res: Vec = Vec::with_capacity(n); + let mut res: Vec> = Vec::with_capacity(n); for i in 0..n { - let lhs: Scalar = args[i].to_owned().into(); - let rhs: Scalar = args[i + n].to_owned().into(); + let lhs: Scalar = args[i].to_owned().into(); + let rhs: Scalar = args[i + n].to_owned().into(); res.push(ResultValue::Scalar(lhs - rhs)); } @@ -450,7 +450,7 @@ impl Neg for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id()], |args| { // Cast the args - let lhs: Scalar = args[0].to_owned().into(); + let lhs: Scalar = args[0].to_owned().into(); ResultValue::Scalar(-lhs) }) .into() @@ -494,7 +494,7 @@ impl Mul<&Scalar> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id()], move |args| { // Cast the args - let lhs: Scalar = args[0].to_owned().into(); + let lhs: Scalar = args[0].to_owned().into(); ResultValue::Scalar(lhs * rhs) }) .into() @@ -510,8 +510,8 @@ impl Mul<&ScalarResult> for &MpcScalarResult { self.fabric() .new_gate_op(vec![self.id(), rhs.id()], move |mut args| { // Cast the args - let lhs: Scalar = args.remove(0).into(); - let rhs: Scalar = args.remove(0).into(); + let lhs: Scalar = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); ResultValue::Scalar(lhs * rhs) }) @@ -599,10 +599,10 @@ impl MpcScalarResult { let scalars: Vec> = fabric.new_batch_gate_op(ids, n /* output_arity */, move |args| { - let mut res: Vec = Vec::with_capacity(n); + let mut res: Vec> = Vec::with_capacity(n); for i in 0..n { - let lhs: Scalar = args[i].to_owned().into(); - let rhs: Scalar = args[i + n].to_owned().into(); + let lhs: Scalar = args[i].to_owned().into(); + let rhs: Scalar = args[i + n].to_owned().into(); res.push(ResultValue::Scalar(lhs * rhs)); } @@ -623,7 +623,7 @@ impl Mul<&MpcScalarResult> for &CurvePoint { let self_owned = *self; rhs.fabric() .new_gate_op(vec![rhs.id()], move |mut args| { - let rhs: Scalar = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); ResultValue::Point(self_owned * rhs) }) @@ -638,8 +638,8 @@ impl Mul<&MpcScalarResult> for &CurvePointResult { fn mul(self, rhs: &MpcScalarResult) -> Self::Output { self.fabric .new_gate_op(vec![self.id(), rhs.id()], |mut args| { - let lhs: CurvePoint = args.remove(0).into(); - let rhs: Scalar = args.remove(0).into(); + let lhs: CurvePoint = args.remove(0).into(); + let rhs: Scalar = args.remove(0).into(); ResultValue::Point(lhs * rhs) }) diff --git a/src/algebra/scalar.rs b/src/algebra/scalar.rs index e925e6d..b0f0a2d 100644 --- a/src/algebra/scalar.rs +++ b/src/algebra/scalar.rs @@ -11,10 +11,10 @@ use std::{ }; use ark_ec::CurveGroup; -use ark_ff::{batch_inversion, Field, MontConfig, PrimeField}; +use ark_ff::{batch_inversion, Field, PrimeField}; use itertools::Itertools; use num_bigint::BigUint; -use rand::{CryptoRng, Rng, RngCore}; +use rand::{CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; use crate::fabric::{ResultHandle, ResultValue}; @@ -30,7 +30,7 @@ use super::macros::{impl_borrow_variants, impl_commutative}; pub const fn n_bytes_field() -> usize { // We add 7 and divide by 8 to emulate a ceiling operation considering that u32 // division is a floor - let n_bits = F::MODULUS_BIT_SIZE; + let n_bits = F::MODULUS_BIT_SIZE as usize; (n_bits + 7) / 8 } @@ -48,26 +48,28 @@ impl Scalar { /// The scalar field's additive identity pub fn zero() -> Self { - Scalar(Self::Field::from(0u8)) + Scalar(C::ScalarField::from(0u8)) } /// The scalar field's multiplicative identity pub fn one() -> Self { - Scalar(Self::Field::from(1)) + Scalar(C::ScalarField::from(1u8)) } /// Get the inner value of the scalar - pub fn inner(&self) -> Self::Field { + pub fn inner(&self) -> C::ScalarField { self.0 } - /// Generate a random scalar + /// Sample a random field element /// - /// n.b. The `rand::random` method uses `ThreadRng` type which implements - /// the `CryptoRng` traits + /// TODO: Validate that this gives a uniform distribution over the field pub fn random(rng: &mut R) -> Self { - let inner: Self::Field = rng.sample(rand::distributions::Standard); - Scalar(inner) + let mut random_bytes = vec![0u8; n_bytes_field::()]; + rng.fill_bytes(&mut random_bytes); + + let val = C::ScalarField::from_random_bytes(&random_bytes).unwrap(); + Self(val) } /// Compute the multiplicative inverse of the scalar in its field @@ -87,7 +89,7 @@ impl Scalar { /// Construct a scalar from the given bytes and reduce modulo the field's modulus pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Self { - let inner = Self::Field::from_be_bytes_mod_order(bytes); + let inner = C::ScalarField::from_be_bytes_mod_order(bytes); Scalar(inner) } @@ -99,7 +101,7 @@ impl Scalar { let val_biguint = self.to_biguint(); let mut bytes = val_biguint.to_bytes_be(); - let n_bytes = n_bytes_field::(); + let n_bytes = n_bytes_field::(); let mut padding = vec![0u8; n_bytes - bytes.len()]; padding.append(&mut bytes); @@ -114,7 +116,7 @@ impl Scalar { /// Convert from a `BigUint` pub fn from_biguint(val: &BigUint) -> Self { let le_bytes = val.to_bytes_le(); - let inner = Self::Field::from_le_bytes_mod_order(&le_bytes); + let inner = C::ScalarField::from_le_bytes_mod_order(&le_bytes); Scalar(inner) } } @@ -415,9 +417,45 @@ impl MulAssign for Scalar { // | Conversions | // --------------- -impl> From for Scalar { - fn from(val: T) -> Self { - Scalar(val.into()) +impl From for Scalar { + fn from(value: bool) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: u8) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: u16) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: u32) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: u64) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: u128) -> Self { + Scalar(C::ScalarField::from(value)) + } +} + +impl From for Scalar { + fn from(value: usize) -> Self { + Scalar(C::ScalarField::from(value as u64)) } } diff --git a/src/beaver.rs b/src/beaver.rs index 58957f9..7dc6efd 100644 --- a/src/beaver.rs +++ b/src/beaver.rs @@ -1,6 +1,7 @@ //! Defines the Beaver value generation interface //! as well as a dummy beaver interface for testing +use ark_ec::CurveGroup; use itertools::Itertools; use crate::algebra::scalar::Scalar; @@ -10,38 +11,42 @@ use crate::algebra::scalar::Scalar; /// x_1 and party 2 holds x_2 such that x_1 + x_2 = x /// 2. Beaver triplets; additively shared values [a], [b], [c] such /// that a * b = c -pub trait SharedValueSource: Send + Sync { +pub trait SharedValueSource: Send + Sync { /// Fetch the next shared single bit - fn next_shared_bit(&mut self) -> Scalar; + fn next_shared_bit(&mut self) -> Scalar; /// Fetch the next shared batch of bits - fn next_shared_bit_batch(&mut self, num_values: usize) -> Vec { + fn next_shared_bit_batch(&mut self, num_values: usize) -> Vec> { (0..num_values) .map(|_| self.next_shared_bit()) .collect_vec() } /// Fetch the next shared single value - fn next_shared_value(&mut self) -> Scalar; + fn next_shared_value(&mut self) -> Scalar; /// Fetch a batch of shared single values - fn next_shared_value_batch(&mut self, num_values: usize) -> Vec { + fn next_shared_value_batch(&mut self, num_values: usize) -> Vec> { (0..num_values) .map(|_| self.next_shared_value()) .collect_vec() } /// Fetch the next pair of values that are multiplicative inverses of one another - fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar); + fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar); /// Fetch the next batch of multiplicative inverse pairs - fn next_shared_inverse_pair_batch(&mut self, num_pairs: usize) -> (Vec, Vec) { + fn next_shared_inverse_pair_batch( + &mut self, + num_pairs: usize, + ) -> (Vec>, Vec>) { (0..num_pairs) .map(|_| self.next_shared_inverse_pair()) .unzip() } /// Fetch the next beaver triplet - fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar); + fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar); /// Fetch a batch of beaver triplets + #[allow(clippy::type_complexity)] fn next_triplet_batch( &mut self, num_triplets: usize, - ) -> (Vec, Vec, Vec) { + ) -> (Vec>, Vec>, Vec>) { let mut a_vals = Vec::with_capacity(num_triplets); let mut b_vals = Vec::with_capacity(num_triplets); let mut c_vals = Vec::with_capacity(num_triplets); @@ -76,14 +81,14 @@ impl PartyIDBeaverSource { /// The PartyIDBeaverSource returns beaver triplets split statically between the /// parties. We assume a = 2, b = 3 ==> c = 6. [a] = (1, 1); [b] = (3, 0) [c] = (2, 4) #[cfg(any(feature = "test_helpers", test))] -impl SharedValueSource for PartyIDBeaverSource { - fn next_shared_bit(&mut self) -> Scalar { +impl SharedValueSource for PartyIDBeaverSource { + fn next_shared_bit(&mut self) -> Scalar { // Simply output partyID, assume partyID \in {0, 1} assert!(self.party_id == 0 || self.party_id == 1); Scalar::from(self.party_id) } - fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) { + fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) { if self.party_id == 0 { (Scalar::from(1u64), Scalar::from(3u64), Scalar::from(2u64)) } else { @@ -91,11 +96,11 @@ impl SharedValueSource for PartyIDBeaverSource { } } - fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar) { + fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar) { (Scalar::from(self.party_id), Scalar::from(self.party_id)) } - fn next_shared_value(&mut self) -> Scalar { + fn next_shared_value(&mut self) -> Scalar { Scalar::from(self.party_id) } } diff --git a/src/commitment.rs b/src/commitment.rs index 1de2474..b96f13f 100644 --- a/src/commitment.rs +++ b/src/commitment.rs @@ -1,13 +1,14 @@ //! Defines Pedersen commitments over the Stark curve used to commit to a value //! before opening it +use ark_ec::CurveGroup; use rand::thread_rng; use sha3::{Digest, Sha3_256}; use crate::{ algebra::{ + curve::{CurvePoint, CurvePointResult}, scalar::{Scalar, ScalarResult}, - stark_curve::{StarkPoint, StarkPointResult}, }, fabric::ResultValue, }; @@ -15,19 +16,19 @@ use crate::{ /// A handle on the result of a Pedersen commitment, including the committed secret /// /// Of the form `value * G + blinder * H` -pub(crate) struct PedersenCommitment { +pub(crate) struct PedersenCommitment { /// The committed value - pub(crate) value: Scalar, + pub(crate) value: Scalar, /// The commitment blinder - pub(crate) blinder: Scalar, + pub(crate) blinder: Scalar, /// The value of the commitment - pub(crate) commitment: StarkPoint, + pub(crate) commitment: CurvePoint, } -impl PedersenCommitment { +impl PedersenCommitment { /// Verify that the given commitment is valid pub(crate) fn verify(&self) -> bool { - let generator = StarkPoint::generator(); + let generator = CurvePoint::generator(); let commitment = generator * self.value + generator * self.blinder; commitment == self.commitment @@ -35,23 +36,23 @@ impl PedersenCommitment { } /// A Pedersen commitment that has been allocated in an MPC computation graph -pub(crate) struct PedersenCommitmentResult { +pub(crate) struct PedersenCommitmentResult { /// The committed value - pub(crate) value: ScalarResult, + pub(crate) value: ScalarResult, /// The commitment blinder - pub(crate) blinder: Scalar, + pub(crate) blinder: Scalar, /// The value of the commitment - pub(crate) commitment: StarkPointResult, + pub(crate) commitment: CurvePointResult, } -impl PedersenCommitmentResult { +impl PedersenCommitmentResult { /// Create a new Pedersen commitment to an underlying value - pub(crate) fn commit(value: ScalarResult) -> PedersenCommitmentResult { + pub(crate) fn commit(value: ScalarResult) -> PedersenCommitmentResult { // Concretely, we use the curve generator for both `G` and `H` as is done // in dalek-cryptography: https://github.com/dalek-cryptography/bulletproofs/blob/main/src/generators.rs#L44-L53 let mut rng = thread_rng(); let blinder = Scalar::random(&mut rng); - let generator = StarkPoint::generator(); + let generator = CurvePoint::generator(); let commitment = generator * &value + generator * blinder; PedersenCommitmentResult { @@ -69,16 +70,16 @@ impl PedersenCommitmentResult { /// We use hash commitments to commit to curve points before opening them. There is no straightforward /// way to adapt Pedersen commitments to curve points, and we do not need the homomorphic properties /// of a Pedersen commitment -pub(crate) struct HashCommitment { +pub(crate) struct HashCommitment { /// The committed value - pub(crate) value: StarkPoint, + pub(crate) value: CurvePoint, /// The blinder used in the commitment - pub(crate) blinder: Scalar, + pub(crate) blinder: Scalar, /// The value of the commitment - pub(crate) commitment: Scalar, + pub(crate) commitment: Scalar, } -impl HashCommitment { +impl HashCommitment { /// Verify that the given commitment is valid pub(crate) fn verify(&self) -> bool { // Create the bytes buffer @@ -97,22 +98,22 @@ impl HashCommitment { } /// A hash commitment that has been allocated in an MPC computation graph -pub(crate) struct HashCommitmentResult { +pub(crate) struct HashCommitmentResult { /// The committed value - pub(crate) value: StarkPointResult, + pub(crate) value: CurvePointResult, /// The blinder used in the commitment - pub(crate) blinder: Scalar, + pub(crate) blinder: Scalar, /// The value of the commitment - pub(crate) commitment: ScalarResult, + pub(crate) commitment: ScalarResult, } -impl HashCommitmentResult { +impl HashCommitmentResult { /// Create a new hash commitment to an underlying value - pub(crate) fn commit(value: StarkPointResult) -> HashCommitmentResult { + pub(crate) fn commit(value: CurvePointResult) -> HashCommitmentResult { let mut rng = thread_rng(); let blinder = Scalar::random(&mut rng); let comm = value.fabric.new_gate_op(vec![value.id], move |mut args| { - let value: StarkPoint = args.remove(0).into(); + let value: CurvePoint = args.remove(0).into(); // Create the bytes buffer let mut bytes = value.to_bytes(); diff --git a/src/fabric.rs b/src/fabric.rs index b9a993c..1ecd054 100644 --- a/src/fabric.rs +++ b/src/fabric.rs @@ -9,6 +9,7 @@ mod executor; mod network_sender; mod result; +use ark_ec::CurveGroup; #[cfg(feature = "benchmarks")] pub use executor::{Executor, ExecutorMessage}; #[cfg(not(feature = "benchmarks"))] @@ -34,12 +35,12 @@ use itertools::Itertools; use crate::{ algebra::{ + authenticated_curve::AuthenticatedPointResult, authenticated_scalar::AuthenticatedScalarResult, - authenticated_stark_point::AuthenticatedStarkPointResult, + curve::{BatchCurvePointResult, CurvePoint, CurvePointResult}, + mpc_curve::MpcPointResult, mpc_scalar::MpcScalarResult, - mpc_stark_point::MpcStarkPointResult, scalar::{BatchScalarResult, Scalar, ScalarResult}, - stark_curve::{BatchStarkPointResult, StarkPoint, StarkPointResult}, }, beaver::SharedValueSource, network::{MpcNetwork, NetworkOutbound, NetworkPayload, PartyId}, @@ -72,7 +73,7 @@ pub type OperationId = usize; /// /// `N` represents the number of results that this operation outputs #[derive(Clone)] -pub struct Operation { +pub struct Operation { /// Identifier of the result that this operation emits id: OperationId, /// The result ID of the first result in the outputs @@ -84,52 +85,53 @@ pub struct Operation { /// The IDs of the inputs to this operation args: Vec, /// The type of the operation - op_type: OperationType, + op_type: OperationType, } -impl Operation { +impl Operation { /// Get the result IDs for an operation pub fn result_ids(&self) -> Vec { (self.result_id..self.result_id + self.output_arity).collect_vec() } } -impl Debug for Operation { +impl Debug for Operation { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "Operation {}", self.id) } } /// Defines the different types of operations available in the computation graph -pub enum OperationType { +pub enum OperationType { /// A gate operation; may be evaluated locally given its ready inputs Gate { /// The function to apply to the inputs - function: Box) -> ResultValue + Send + Sync>, + function: Box>) -> ResultValue + Send + Sync>, }, /// A gate operation that has output arity greater than one /// /// We separate this out to avoid vector allocation for result values of arity one GateBatch { /// The function to apply to the inputs - function: Box) -> Vec + Send + Sync>, + #[allow(clippy::type_complexity)] + function: Box>) -> Vec> + Send + Sync>, }, /// A network operation, requires that a value be sent over the network Network { /// The function to apply to the inputs to derive a Network payload - function: Box) -> NetworkPayload + Send + Sync>, + function: Box>) -> NetworkPayload + Send + Sync>, }, } /// A clone implementation, never concretely called but used as a Marker type to allow /// pre-allocating buffer space for `Operation`s -impl Clone for OperationType { +impl Clone for OperationType { fn clone(&self) -> Self { panic!("cannot clone `OperationType`") } } -impl Debug for OperationType { +impl Debug for OperationType { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { OperationType::Gate { .. } => write!(f, "Gate"), @@ -147,23 +149,23 @@ impl Debug for OperationType { /// continue using the fabric, scheduling more gates to be evaluated and maximally exploiting /// gate-level parallelism within the circuit #[derive(Clone)] -pub struct MpcFabric { +pub struct MpcFabric { /// The inner fabric #[cfg(not(feature = "benchmarks"))] - inner: Arc, + inner: Arc>, /// The inner fabric, accessible publicly for benchmark mocking #[cfg(feature = "benchmarks")] - pub inner: Arc, + pub inner: Arc>, /// The local party's share of the global MAC key /// /// The parties collectively hold an additive sharing of the global key /// /// We wrap in a reference counting structure to avoid recursive type issues #[cfg(not(feature = "benchmarks"))] - mac_key: Option>, + mac_key: Option>>, /// The MAC key, accessible publicly for benchmark mocking #[cfg(feature = "benchmarks")] - pub mac_key: Option>, + pub mac_key: Option>>, /// The channel on which shutdown messages are sent to blocking workers #[cfg(not(feature = "benchmarks"))] shutdown: BroadcastSender<()>, @@ -172,7 +174,7 @@ pub struct MpcFabric { pub shutdown: BroadcastSender<()>, } -impl Debug for MpcFabric { +impl Debug for MpcFabric { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "MpcFabric") } @@ -181,7 +183,7 @@ impl Debug for MpcFabric { /// The inner component of the fabric, allows the constructor to allocate executor and network /// sender objects at the same level as the fabric #[derive(Clone)] -pub struct FabricInner { +pub struct FabricInner { /// The ID of the local party in the MPC execution party_id: u64, /// The next identifier to assign to a result @@ -189,32 +191,32 @@ pub struct FabricInner { /// The next identifier to assign to an operation next_op_id: Arc, /// A sender to the executor - execution_queue: Arc>, + execution_queue: Arc>>, /// The underlying queue to the network - outbound_queue: KanalSender, + outbound_queue: KanalSender>, /// The underlying shared randomness source - beaver_source: Arc>>, + beaver_source: Arc>>>, } -impl Debug for FabricInner { +impl Debug for FabricInner { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "FabricInner") } } -impl FabricInner { +impl FabricInner { /// Constructor - pub fn new( + pub fn new>( party_id: u64, - execution_queue: Arc>, - outbound_queue: KanalSender, + execution_queue: Arc>>, + outbound_queue: KanalSender>, beaver_source: S, ) -> Self { // Allocate a zero and a one as well as the curve identity in the fabric to begin, // for convenience let zero = ResultValue::Scalar(Scalar::zero()); let one = ResultValue::Scalar(Scalar::one()); - let identity = ResultValue::Point(StarkPoint::identity()); + let identity = ResultValue::Point(CurvePoint::identity()); for initial_result in vec![ OpResult { @@ -249,7 +251,7 @@ impl FabricInner { } /// Register a waiter on a result - pub(crate) fn register_waiter(&self, waiter: ResultWaiter) { + pub(crate) fn register_waiter(&self, waiter: ResultWaiter) { self.execution_queue .push(ExecutorMessage::NewWaiter(waiter)); } @@ -293,7 +295,7 @@ impl FabricInner { // ------------------------ /// Allocate a new plaintext value in the fabric - pub(crate) fn allocate_value(&self, value: ResultValue) -> ResultId { + pub(crate) fn allocate_value(&self, value: ResultValue) -> ResultId { // Forward the result to the executor let id = self.new_result_id(); self.execution_queue @@ -305,8 +307,8 @@ impl FabricInner { /// Allocate a secret shared value in the network pub(crate) fn allocate_shared_value( &self, - my_share: ResultValue, - their_share: ResultValue, + my_share: ResultValue, + their_share: ResultValue, ) -> ResultId { // Forward the local party's share to the executor let id = self.new_result_id(); @@ -343,7 +345,7 @@ impl FabricInner { &self, args: Vec, output_arity: usize, - op_type: OperationType, + op_type: OperationType, ) -> Vec { if matches!(op_type, OperationType::Gate { .. }) { assert_eq!(output_arity, 1, "gate operations must have arity 1"); @@ -370,9 +372,9 @@ impl FabricInner { } } -impl MpcFabric { +impl MpcFabric { /// Constructor - pub fn new( + pub fn new, S: 'static + SharedValueSource>( network: N, beaver_source: S, ) -> Self { @@ -381,7 +383,7 @@ impl MpcFabric { /// Constructor that takes an additional size hint, indicating how much buffer space /// the fabric should allocate for results. The size is given in number of gates - pub fn new_with_size_hint( + pub fn new_with_size_hint, S: 'static + SharedValueSource>( size_hint: usize, network: N, beaver_source: S, @@ -450,12 +452,12 @@ impl MpcFabric { } /// Register a waiter on a result - pub fn register_waiter(&self, waiter: ResultWaiter) { + pub fn register_waiter(&self, waiter: ResultWaiter) { self.inner.register_waiter(waiter); } /// Immutably borrow the MAC key - pub(crate) fn borrow_mac_key(&self) -> &MpcScalarResult { + pub(crate) fn borrow_mac_key(&self) -> &MpcScalarResult { // Unwrap is safe, the constructor sets the MAC key self.mac_key.as_ref().unwrap() } @@ -465,19 +467,19 @@ impl MpcFabric { // ------------------------ /// Get the hardcoded zero wire as a raw `ScalarResult` - pub fn zero(&self) -> ScalarResult { + pub fn zero(&self) -> ScalarResult { ResultHandle::new(self.inner.zero(), self.clone()) } /// Get the shared zero value as an `MpcScalarResult` - fn zero_shared(&self) -> MpcScalarResult { + fn zero_shared(&self) -> MpcScalarResult { MpcScalarResult::new_shared(self.zero()) } /// Get the hardcoded zero wire as an `AuthenticatedScalarResult` /// /// Both parties hold the share 0 directly in this case - pub fn zero_authenticated(&self) -> AuthenticatedScalarResult { + pub fn zero_authenticated(&self) -> AuthenticatedScalarResult { let zero_value = self.zero(); let share_value = self.zero_shared(); let mac_value = self.zero_shared(); @@ -490,25 +492,25 @@ impl MpcFabric { } /// Get a batch of references to the zero wire as an `AuthenticatedScalarResult` - pub fn zeros_authenticated(&self, n: usize) -> Vec { + pub fn zeros_authenticated(&self, n: usize) -> Vec> { let val = self.zero_authenticated(); (0..n).map(|_| val.clone()).collect_vec() } /// Get the hardcoded one wire as a raw `ScalarResult` - pub fn one(&self) -> ScalarResult { + pub fn one(&self) -> ScalarResult { ResultHandle::new(self.inner.one(), self.clone()) } /// Get the hardcoded shared one wire as an `MpcScalarResult` - fn one_shared(&self) -> MpcScalarResult { + fn one_shared(&self) -> MpcScalarResult { MpcScalarResult::new_shared(self.one()) } /// Get the hardcoded one wire as an `AuthenticatedScalarResult` /// /// Party 0 holds the value zero and party 1 holds the value one - pub fn one_authenticated(&self) -> AuthenticatedScalarResult { + pub fn one_authenticated(&self) -> AuthenticatedScalarResult { if self.party_id() == PARTY0 { let zero_value = self.zero(); let share_value = self.zero_shared(); @@ -533,30 +535,30 @@ impl MpcFabric { } /// Get a batch of references to the one wire as an `AuthenticatedScalarResult` - pub fn ones_authenticated(&self, n: usize) -> Vec { + pub fn ones_authenticated(&self, n: usize) -> Vec> { let val = self.one_authenticated(); (0..n).map(|_| val.clone()).collect_vec() } /// Get the hardcoded curve identity wire as a raw `StarkPoint` - pub fn curve_identity(&self) -> ResultHandle { + pub fn curve_identity(&self) -> CurvePointResult { ResultHandle::new(self.inner.curve_identity(), self.clone()) } /// Get the hardcoded shared curve identity wire as an `MpcStarkPointResult` - fn curve_identity_shared(&self) -> MpcStarkPointResult { - MpcStarkPointResult::new_shared(self.curve_identity()) + fn curve_identity_shared(&self) -> MpcPointResult { + MpcPointResult::new_shared(self.curve_identity()) } /// Get the hardcoded curve identity wire as an `AuthenticatedStarkPointResult` /// /// Both parties hold the identity point directly in this case - pub fn curve_identity_authenticated(&self) -> AuthenticatedStarkPointResult { + pub fn curve_identity_authenticated(&self) -> AuthenticatedPointResult { let identity_val = self.curve_identity(); let share_value = self.curve_identity_shared(); let mac_value = self.curve_identity_shared(); - AuthenticatedStarkPointResult { + AuthenticatedPointResult { share: share_value, mac: mac_value, public_modifier: identity_val, @@ -568,22 +570,22 @@ impl MpcFabric { // ------------------- /// Allocate a shared value in the fabric - fn allocate_shared_value>( + fn allocate_shared_value>>( &self, - my_share: ResultValue, - their_share: ResultValue, - ) -> ResultHandle { + my_share: ResultValue, + their_share: ResultValue, + ) -> ResultHandle { let id = self.inner.allocate_shared_value(my_share, their_share); ResultHandle::new(id, self.clone()) } /// Share a `Scalar` value with the counterparty - pub fn share_scalar>( + pub fn share_scalar>>( &self, val: T, sender: PartyId, - ) -> AuthenticatedScalarResult { - let scalar: ScalarResult = if self.party_id() == sender { + ) -> AuthenticatedScalarResult { + let scalar: ScalarResult = if self.party_id() == sender { let scalar_val = val.into(); let mut rng = thread_rng(); let random = Scalar::random(&mut rng); @@ -601,13 +603,13 @@ impl MpcFabric { } /// Share a batch of `Scalar` values with the counterparty - pub fn batch_share_scalar>( + pub fn batch_share_scalar>>( &self, vals: Vec, sender: PartyId, - ) -> Vec { + ) -> Vec> { let n = vals.len(); - let shares: BatchScalarResult = if self.party_id() == sender { + let shares: BatchScalarResult = if self.party_id() == sender { let vals = vals.into_iter().map(|val| val.into()).collect_vec(); let mut rng = thread_rng(); @@ -632,8 +634,8 @@ impl MpcFabric { } /// Share a `StarkPoint` value with the counterparty - pub fn share_point(&self, val: StarkPoint, sender: PartyId) -> AuthenticatedStarkPointResult { - let point: StarkPointResult = if self.party_id() == sender { + pub fn share_point(&self, val: CurvePoint, sender: PartyId) -> AuthenticatedPointResult { + let point: CurvePointResult = if self.party_id() == sender { // As mentioned in https://eprint.iacr.org/2009/226.pdf // it is okay to sample a random point by sampling a random `Scalar` and multiplying // by the generator in the case that the discrete log of the output may be leaked with @@ -641,7 +643,7 @@ impl MpcFabric { // when it is used to generate secret shares let mut rng = thread_rng(); let random = Scalar::random(&mut rng); - let random_point = random * StarkPoint::generator(); + let random_point = random * CurvePoint::generator(); let (my_share, their_share) = (val - random_point, random_point); self.allocate_shared_value( @@ -652,19 +654,19 @@ impl MpcFabric { self.receive_value() }; - AuthenticatedStarkPointResult::new_shared(point) + AuthenticatedPointResult::new_shared(point) } /// Share a batch of `StarkPoint`s with the counterparty pub fn batch_share_point( &self, - vals: Vec, + vals: Vec>, sender: PartyId, - ) -> Vec { + ) -> Vec> { let n = vals.len(); - let shares: BatchStarkPointResult = if self.party_id() == sender { + let shares: BatchCurvePointResult = if self.party_id() == sender { let mut rng = thread_rng(); - let generator = StarkPoint::generator(); + let generator = CurvePoint::generator(); let peer_shares = (0..vals.len()) .map(|_| { let discrete_log = Scalar::random(&mut rng); @@ -685,17 +687,17 @@ impl MpcFabric { self.receive_value() }; - AuthenticatedStarkPointResult::new_shared_from_batch_result(shares, n) + AuthenticatedPointResult::new_shared_from_batch_result(shares, n) } /// Allocate a public value in the fabric - pub fn allocate_scalar>(&self, value: T) -> ResultHandle { + pub fn allocate_scalar>>(&self, value: T) -> ScalarResult { let id = self.inner.allocate_value(ResultValue::Scalar(value.into())); ResultHandle::new(id, self.clone()) } /// Allocate a batch of scalars in the fabric - pub fn allocate_scalars>(&self, values: Vec) -> Vec> { + pub fn allocate_scalars>>(&self, values: Vec) -> Vec> { values .into_iter() .map(|value| self.allocate_scalar(value)) @@ -703,31 +705,31 @@ impl MpcFabric { } /// Allocate a scalar as a secret share of an already shared value - pub fn allocate_preshared_scalar>( + pub fn allocate_preshared_scalar>>( &self, value: T, - ) -> AuthenticatedScalarResult { + ) -> AuthenticatedScalarResult { let allocated = self.allocate_scalar(value); AuthenticatedScalarResult::new_shared(allocated) } /// Allocate a batch of scalars as secret shares of already shared values - pub fn batch_allocate_preshared_scalar>( + pub fn batch_allocate_preshared_scalar>>( &self, values: Vec, - ) -> Vec { + ) -> Vec> { let values = self.allocate_scalars(values); AuthenticatedScalarResult::new_shared_batch(&values) } /// Allocate a public curve point in the fabric - pub fn allocate_point(&self, value: StarkPoint) -> ResultHandle { + pub fn allocate_point(&self, value: CurvePoint) -> CurvePointResult { let id = self.inner.allocate_value(ResultValue::Point(value)); ResultHandle::new(id, self.clone()) } /// Allocate a batch of points in the fabric - pub fn allocate_points(&self, values: Vec) -> Vec> { + pub fn allocate_points(&self, values: Vec>) -> Vec> { values .into_iter() .map(|value| self.allocate_point(value)) @@ -735,18 +737,18 @@ impl MpcFabric { } /// Send a value to the peer, placing the identity in the local result buffer at the send ID - pub fn send_value + Into>( + pub fn send_value> + Into>>( &self, - value: ResultHandle, - ) -> ResultHandle { + value: ResultHandle, + ) -> ResultHandle { self.new_network_op(vec![value.id], |mut args| args.remove(0).into()) } /// Send a batch of values to the counterparty - pub fn send_values(&self, values: &[ResultHandle]) -> ResultHandle> + pub fn send_values(&self, values: &[ResultHandle]) -> ResultHandle> where - T: From, - Vec: Into + From, + T: From>, + Vec: Into> + From>, { let ids = values.iter().map(|v| v.id).collect_vec(); self.new_network_op(ids, |args| { @@ -756,7 +758,7 @@ impl MpcFabric { } /// Receive a value from the peer - pub fn receive_value>(&self) -> ResultHandle { + pub fn receive_value>>(&self) -> ResultHandle { let id = self.inner.receive_value(); ResultHandle::new(id, self.clone()) } @@ -765,10 +767,10 @@ impl MpcFabric { /// based on the party ID /// /// Returns a handle to the received value, which will be different for different parties - pub fn exchange_value + Into>( + pub fn exchange_value> + Into>>( &self, - value: ResultHandle, - ) -> ResultHandle { + value: ResultHandle, + ) -> ResultHandle { if self.party_id() == PARTY0 { // Party 0 sends first then receives self.send_value(value); @@ -783,10 +785,10 @@ impl MpcFabric { /// Exchange a batch of values with the peer, i.e. send then receive or receive then send /// based on party ID - pub fn exchange_values(&self, values: &[ResultHandle]) -> ResultHandle> + pub fn exchange_values(&self, values: &[ResultHandle]) -> ResultHandle> where - T: From, - Vec: From + Into, + T: From>, + Vec: From> + Into>, { if self.party_id() == PARTY0 { self.send_values(values); @@ -799,9 +801,9 @@ impl MpcFabric { } /// Share a public value with the counterparty - pub fn share_plaintext(&self, value: T, sender: PartyId) -> ResultHandle + pub fn share_plaintext(&self, value: T, sender: PartyId) -> ResultHandle where - T: 'static + From + Into + Send + Sync, + T: 'static + From> + Into> + Send + Sync, { if self.party_id() == sender { self.new_network_op(vec![], move |_args| value.into()) @@ -811,10 +813,14 @@ impl MpcFabric { } /// Share a batch of public values with the counterparty - pub fn batch_share_plaintext(&self, values: Vec, sender: PartyId) -> ResultHandle> + pub fn batch_share_plaintext( + &self, + values: Vec, + sender: PartyId, + ) -> ResultHandle> where - T: 'static + From + Send + Sync, - Vec: Into + From, + T: 'static + From> + Send + Sync, + Vec: Into> + From>, { self.share_plaintext(values, sender) } @@ -825,10 +831,10 @@ impl MpcFabric { /// Construct a new gate operation in the fabric, i.e. one that can be evaluated immediate given /// its inputs - pub fn new_gate_op(&self, args: Vec, function: F) -> ResultHandle + pub fn new_gate_op(&self, args: Vec, function: F) -> ResultHandle where - F: 'static + FnOnce(Vec) -> ResultValue + Send + Sync, - T: From, + F: 'static + FnOnce(Vec>) -> ResultValue + Send + Sync, + T: From>, { let function = Box::new(function); let id = self.inner.new_op( @@ -849,10 +855,10 @@ impl MpcFabric { args: Vec, output_arity: usize, function: F, - ) -> Vec> + ) -> Vec> where - F: 'static + FnOnce(Vec) -> Vec + Send + Sync, - T: From, + F: 'static + FnOnce(Vec>) -> Vec> + Send + Sync, + T: From>, { let function = Box::new(function); let ids = self @@ -865,10 +871,10 @@ impl MpcFabric { /// Construct a new network operation in the fabric, i.e. one that requires a value to be sent /// over the channel - pub fn new_network_op(&self, args: Vec, function: F) -> ResultHandle + pub fn new_network_op(&self, args: Vec, function: F) -> ResultHandle where - F: 'static + FnOnce(Vec) -> NetworkPayload + Send + Sync, - T: From, + F: 'static + FnOnce(Vec>) -> NetworkPayload + Send + Sync, + T: From>, { let function = Box::new(function); let id = self.inner.new_op( @@ -884,7 +890,9 @@ impl MpcFabric { // ----------------- /// Sample the next beaver triplet from the beaver source - pub fn next_beaver_triple(&self) -> (MpcScalarResult, MpcScalarResult, MpcScalarResult) { + pub fn next_beaver_triple( + &self, + ) -> (MpcScalarResult, MpcScalarResult, MpcScalarResult) { // Sample the triple and allocate it in the fabric, the counterparty will do the same let (a, b, c) = self .inner @@ -905,13 +913,14 @@ impl MpcFabric { } /// Sample a batch of beaver triples + #[allow(clippy::type_complexity)] pub fn next_beaver_triple_batch( &self, n: usize, ) -> ( - Vec, - Vec, - Vec, + Vec>, + Vec>, + Vec>, ) { let (a_vals, b_vals, c_vals) = self .inner @@ -946,9 +955,9 @@ impl MpcFabric { pub fn next_authenticated_triple( &self, ) -> ( - AuthenticatedScalarResult, - AuthenticatedScalarResult, - AuthenticatedScalarResult, + AuthenticatedScalarResult, + AuthenticatedScalarResult, + AuthenticatedScalarResult, ) { let (a, b, c) = self .inner @@ -969,13 +978,14 @@ impl MpcFabric { } /// Sample the next batch of beaver triples as `AuthenticatedScalar`s + #[allow(clippy::type_complexity)] pub fn next_authenticated_triple_batch( &self, n: usize, ) -> ( - Vec, - Vec, - Vec, + Vec>, + Vec>, + Vec>, ) { let (a_vals, b_vals, c_vals) = self .inner @@ -996,7 +1006,7 @@ impl MpcFabric { } /// Sample a batch of random shared values from the beaver source - pub fn random_shared_scalars(&self, n: usize) -> Vec { + pub fn random_shared_scalars(&self, n: usize) -> Vec> { let values_raw = self .inner .beaver_source @@ -1012,7 +1022,10 @@ impl MpcFabric { } /// Sample a batch of random shared values from the beaver source and allocate them as `AuthenticatedScalars` - pub fn random_shared_scalars_authenticated(&self, n: usize) -> Vec { + pub fn random_shared_scalars_authenticated( + &self, + n: usize, + ) -> Vec> { let values_raw = self .inner .beaver_source @@ -1031,7 +1044,9 @@ impl MpcFabric { } /// Sample a pair of values that are multiplicative inverses of one another - pub fn random_inverse_pair(&self) -> (AuthenticatedScalarResult, AuthenticatedScalarResult) { + pub fn random_inverse_pair( + &self, + ) -> (AuthenticatedScalarResult, AuthenticatedScalarResult) { let (l, r) = self .inner .beaver_source @@ -1049,8 +1064,8 @@ impl MpcFabric { &self, n: usize, ) -> ( - Vec, - Vec, + Vec>, + Vec>, ) { let (left, right) = self .inner @@ -1070,7 +1085,7 @@ impl MpcFabric { } /// Sample a random shared bit from the beaver source - pub fn random_shared_bit(&self) -> AuthenticatedScalarResult { + pub fn random_shared_bit(&self) -> AuthenticatedScalarResult { let bit = self .inner .beaver_source @@ -1083,7 +1098,7 @@ impl MpcFabric { } /// Sample a batch of random shared bits from the beaver source - pub fn random_shared_bits(&self, n: usize) -> Vec { + pub fn random_shared_bits(&self, n: usize) -> Vec> { let bits = self .inner .beaver_source diff --git a/src/fabric/executor.rs b/src/fabric/executor.rs index 831aeca..b660fb4 100644 --- a/src/fabric/executor.rs +++ b/src/fabric/executor.rs @@ -2,9 +2,10 @@ //! them, and places the result back into the fabric for further executions use std::collections::HashMap; -use std::fmt::{Debug, Formatter, Result as FmtResult}; +use std::fmt::Debug; use std::sync::Arc; +use ark_ec::CurveGroup; use crossbeam::queue::SegQueue; use itertools::Itertools; use tracing::log; @@ -107,19 +108,19 @@ impl Debug for ExecutorStats { /// The executor is responsible for executing operation that are ready for execution, either /// passed explicitly by the fabric or as a result of a dependency being satisfied -pub struct Executor { +pub struct Executor { /// The job queue for the executor - job_queue: Arc>, + job_queue: Arc>>, /// The operation buffer, stores in-flight operations - operations: GrowableBuffer, + operations: GrowableBuffer>, /// The dependency map; maps in-flight results to operations that are waiting for them dependencies: GrowableBuffer>, /// The completed results of operations - results: GrowableBuffer, + results: GrowableBuffer>, /// An index of waiters for incomplete results - waiters: HashMap>, + waiters: HashMap>>, /// The underlying fabric that the executor is a part of - fabric: FabricInner, + fabric: FabricInner, /// The collected statistics of the executor #[cfg(feature = "stats")] stats: ExecutorStats, @@ -132,23 +133,23 @@ pub struct Executor { /// arguments are ready /// - A new waiter for a result, which the executor will add to its waiter map #[derive(Debug)] -pub enum ExecutorMessage { +pub enum ExecutorMessage { /// A result of an operation - Result(OpResult), + Result(OpResult), /// An operation that is ready for execution - Op(Operation), + Op(Operation), /// A new waiter has registered itself for a result - NewWaiter(ResultWaiter), + NewWaiter(ResultWaiter), /// Indicates that the executor should shut down Shutdown, } -impl Executor { +impl Executor { /// Constructor pub fn new( circuit_size_hint: usize, - job_queue: Arc>, - fabric: FabricInner, + job_queue: Arc>>, + fabric: FabricInner, ) -> Self { #[cfg(feature = "stats")] { @@ -202,7 +203,7 @@ impl Executor { } /// Handle a new result - fn handle_new_result(&mut self, result: OpResult) { + fn handle_new_result(&mut self, result: OpResult) { let id = result.id; let prev = self.results.insert(result.id, result); assert!(prev.is_none(), "duplicate result id: {id:?}"); @@ -232,7 +233,7 @@ impl Executor { } /// Handle a new operation - fn handle_new_operation(&mut self, mut op: Operation) { + fn handle_new_operation(&mut self, mut op: Operation) { #[cfg(feature = "stats")] { self.record_op_depth(&op); @@ -279,7 +280,7 @@ impl Executor { } /// Executes an operation whose arguments are ready - fn execute_operation(&mut self, op: Operation) { + fn execute_operation(&mut self, op: Operation) { let result_ids = op.result_ids(); // Collect the inputs to the operation @@ -332,7 +333,7 @@ impl Executor { } /// Handle a new waiter for a result - pub fn handle_new_waiter(&mut self, waiter: ResultWaiter) { + pub fn handle_new_waiter(&mut self, waiter: ResultWaiter) { let id = waiter.result_id; // Insert the new waiter to the queue diff --git a/src/fabric/network_sender.rs b/src/fabric/network_sender.rs index f2d7487..d2d9311 100644 --- a/src/fabric/network_sender.rs +++ b/src/fabric/network_sender.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use ark_ec::CurveGroup; use crossbeam::queue::SegQueue; use futures::stream::SplitSink; use futures::SinkExt; @@ -71,22 +72,22 @@ impl NetworkStats { /// The network sender sits behind the scheduler and is responsible for forwarding messages /// onto the network and pulling results off the network, re-enqueuing them for processing -pub(crate) struct NetworkSender { +pub(crate) struct NetworkSender> { /// The outbound queue of messages to send - outbound: KanalReceiver, + outbound: KanalReceiver>, /// The queue of completed results - result_queue: Arc>, + result_queue: Arc>>, /// The underlying network connection network: N, /// The broadcast channel on which shutdown signals are sent shutdown: BroadcastReceiver<()>, } -impl NetworkSender { +impl + 'static> NetworkSender { /// Creates a new network sender pub fn new( - outbound: KanalReceiver, - result_queue: Arc>, + outbound: KanalReceiver>, + result_queue: Arc>>, network: N, shutdown: BroadcastReceiver<()>, ) -> Self { @@ -112,7 +113,7 @@ impl NetworkSender { let stats = Arc::new(NetworkStats::default()); // Start a read and write loop separately - let (send, recv): (SplitSink, SplitStream) = network.split(); + let (send, recv): (SplitSink>, SplitStream) = network.split(); let read_loop_fut = tokio::spawn(Self::read_loop(recv, result_queue, stats.clone())); let write_loop_fut = tokio::spawn(Self::write_loop(outbound, send, stats.clone())); @@ -138,7 +139,7 @@ impl NetworkSender { /// with the executor async fn read_loop( mut network_stream: SplitStream, - result_queue: Arc>, + result_queue: Arc>>, stats: Arc, ) -> MpcNetworkError { while let Some(Ok(msg)) = network_stream.next().await { @@ -161,8 +162,8 @@ impl NetworkSender { /// The write loop for the network, reads messages from the outbound queue and sends them /// onto the network async fn write_loop( - outbound_stream: KanalReceiver, - mut network: SplitSink, + outbound_stream: KanalReceiver>, + mut network: SplitSink>, stats: Arc, ) -> MpcNetworkError { while let Ok(msg) = outbound_stream.recv().await { diff --git a/src/fabric/result.rs b/src/fabric/result.rs index 50f64f9..8d027bf 100644 --- a/src/fabric/result.rs +++ b/src/fabric/result.rs @@ -10,10 +10,11 @@ use std::{ task::{Context, Poll, Waker}, }; +use ark_ec::CurveGroup; use futures::Future; use crate::{ - algebra::{scalar::Scalar, stark_curve::StarkPoint}, + algebra::{curve::CurvePoint, scalar::Scalar}, network::NetworkPayload, Shared, }; @@ -32,29 +33,29 @@ pub type ResultId = usize; /// The result of an MPC operation #[derive(Clone, Debug)] -pub struct OpResult { +pub struct OpResult { /// The ID of the result's output pub id: ResultId, /// The result's value - pub value: ResultValue, + pub value: ResultValue, } /// The value of a result #[derive(Clone)] -pub enum ResultValue { +pub enum ResultValue { /// A byte value Bytes(Vec), /// A scalar value - Scalar(Scalar), + Scalar(Scalar), /// A batch of scalars - ScalarBatch(Vec), + ScalarBatch(Vec>), /// A point on the curve - Point(StarkPoint), + Point(CurvePoint), /// A batch of points on the curve - PointBatch(Vec), + PointBatch(Vec>), } -impl Debug for ResultValue { +impl Debug for ResultValue { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { ResultValue::Bytes(bytes) => f.debug_tuple("Bytes").field(bytes).finish(), @@ -68,8 +69,8 @@ impl Debug for ResultValue { } } -impl From for ResultValue { - fn from(value: NetworkPayload) -> Self { +impl From> for ResultValue { + fn from(value: NetworkPayload) -> Self { match value { NetworkPayload::Bytes(bytes) => ResultValue::Bytes(bytes), NetworkPayload::Scalar(scalar) => ResultValue::Scalar(scalar), @@ -80,8 +81,8 @@ impl From for ResultValue { } } -impl From for NetworkPayload { - fn from(value: ResultValue) -> Self { +impl From> for NetworkPayload { + fn from(value: ResultValue) -> Self { match value { ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes), ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar), @@ -93,8 +94,8 @@ impl From for NetworkPayload { } // -- Coercive Casts to Concrete Types -- // -impl From for Vec { - fn from(value: ResultValue) -> Self { +impl From> for Vec { + fn from(value: ResultValue) -> Self { match value { ResultValue::Bytes(bytes) => bytes, _ => panic!("Cannot cast {:?} to bytes", value), @@ -102,8 +103,8 @@ impl From for Vec { } } -impl From for Scalar { - fn from(value: ResultValue) -> Self { +impl From> for Scalar { + fn from(value: ResultValue) -> Self { match value { ResultValue::Scalar(scalar) => scalar, _ => panic!("Cannot cast {:?} to scalar", value), @@ -111,8 +112,8 @@ impl From for Scalar { } } -impl From<&ResultValue> for Scalar { - fn from(value: &ResultValue) -> Self { +impl From<&ResultValue> for Scalar { + fn from(value: &ResultValue) -> Self { match value { ResultValue::Scalar(scalar) => *scalar, _ => panic!("Cannot cast {:?} to scalar", value), @@ -120,8 +121,8 @@ impl From<&ResultValue> for Scalar { } } -impl From for Vec { - fn from(value: ResultValue) -> Self { +impl From> for Vec> { + fn from(value: ResultValue) -> Self { match value { ResultValue::ScalarBatch(scalars) => scalars, _ => panic!("Cannot cast {:?} to scalar batch", value), @@ -129,8 +130,8 @@ impl From for Vec { } } -impl From for StarkPoint { - fn from(value: ResultValue) -> Self { +impl From> for CurvePoint { + fn from(value: ResultValue) -> Self { match value { ResultValue::Point(point) => point, _ => panic!("Cannot cast {:?} to point", value), @@ -138,8 +139,8 @@ impl From for StarkPoint { } } -impl From<&ResultValue> for StarkPoint { - fn from(value: &ResultValue) -> Self { +impl From<&ResultValue> for CurvePoint { + fn from(value: &ResultValue) -> Self { match value { ResultValue::Point(point) => *point, _ => panic!("Cannot cast {:?} to point", value), @@ -147,8 +148,8 @@ impl From<&ResultValue> for StarkPoint { } } -impl From for Vec { - fn from(value: ResultValue) -> Self { +impl From> for Vec> { + fn from(value: ResultValue) -> Self { match value { ResultValue::PointBatch(points) => points, _ => panic!("Cannot cast {:?} to point batch", value), @@ -168,32 +169,32 @@ impl From for Vec { /// This allows for construction of the graph concurrently with execution, giving the /// fabric the opportunity to schedule all results onto the network optimistically #[derive(Clone, Debug)] -pub struct ResultHandle> { +pub struct ResultHandle>> { /// The id of the result pub(crate) id: ResultId, /// The buffer that the result will be written to when it becomes available - pub(crate) result_buffer: Shared>, + pub(crate) result_buffer: Shared>>, /// The underlying fabric - pub(crate) fabric: MpcFabric, + pub(crate) fabric: MpcFabric, /// A phantom for the type of the result phantom: PhantomData, } -impl> ResultHandle { +impl>> ResultHandle { /// Get the id of the result pub fn id(&self) -> ResultId { self.id } /// Borrow the fabric that this result is allocated within - pub fn fabric(&self) -> &MpcFabric { + pub fn fabric(&self) -> &MpcFabric { &self.fabric } } -impl> ResultHandle { +impl>> ResultHandle { /// Constructor - pub(crate) fn new(id: ResultId, fabric: MpcFabric) -> Self { + pub(crate) fn new(id: ResultId, fabric: MpcFabric) -> Self { Self { id, result_buffer: Arc::new(RwLock::new(None)), @@ -209,16 +210,16 @@ impl> ResultHandle { } /// A struct describing an async task that is waiting on a result -pub struct ResultWaiter { +pub struct ResultWaiter { /// The id of the result that the task is waiting on pub result_id: ResultId, /// The buffer that the result will be written to when it becomes available - pub result_buffer: Shared>, + pub result_buffer: Shared>>, /// The waker of the task pub waker: Waker, } -impl Debug for ResultWaiter { +impl Debug for ResultWaiter { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("ResultWaiter") .field("id", &self.result_id) @@ -226,7 +227,7 @@ impl Debug for ResultWaiter { } } -impl + Debug> Future for ResultHandle { +impl> + Debug> Future for ResultHandle { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/src/lib.rs b/src/lib.rs index 12d82ea..8e522d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,16 +7,11 @@ //! Defines an MPC implementation over the Stark curve that allows for out-of-order execution of //! the underlying MPC circuit -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, RwLock}, -}; +use std::sync::{Arc, RwLock}; -use algebra::{scalar::Scalar, stark_curve::StarkPoint}; -use beaver::SharedValueSource; +use algebra::{curve::CurvePoint, scalar::Scalar}; +use ark_ec::CurveGroup; -use network::MpcNetwork; use rand::thread_rng; pub mod algebra; @@ -45,9 +40,9 @@ pub const PARTY1: u64 = 1; /// Generate a random curve point by multiplying a random scalar with the /// Stark curve group generator -pub fn random_point() -> StarkPoint { +pub fn random_point() -> CurvePoint { let mut rng = thread_rng(); - StarkPoint::generator() * Scalar::random(&mut rng) + CurvePoint::generator() * Scalar::random(&mut rng) } // -------------------- @@ -57,27 +52,20 @@ pub fn random_point() -> StarkPoint { /// A type alias for a shared locked value type Shared = Arc>; -/// SharedNetwork wraps a network implementation in a borrow-safe container -/// while providing interior mutability -#[allow(type_alias_bounds)] -pub type SharedNetwork = Rc>; -/// A type alias for a shared, mutable reference to an underlying beaver source -#[allow(type_alias_bounds)] -pub type BeaverSource = Rc>; - #[cfg(any(test, feature = "test_helpers"))] pub mod test_helpers { //! Defines test helpers for use in unit and integration tests, as well as benchmarks use futures::Future; use crate::{ + algebra::test_helper::TestCurve, beaver::PartyIDBeaverSource, network::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream}, MpcFabric, PARTY0, PARTY1, }; /// Create a mock fabric - pub fn mock_fabric() -> MpcFabric { + pub fn mock_fabric() -> MpcFabric { let network = NoRecvNetwork::default(); let beaver_source = PartyIDBeaverSource::default(); @@ -93,7 +81,7 @@ pub mod test_helpers { where T: Send + 'static, S: Future + Send + 'static, - F: FnMut(MpcFabric) -> S, + F: FnMut(MpcFabric) -> S, { // Build a duplex stream to broker communication between the two parties let (party0_stream, party1_stream) = UnboundedDuplexStream::new_duplex_pair(); diff --git a/src/network.rs b/src/network.rs index d829b3b..3ead927 100644 --- a/src/network.rs +++ b/src/network.rs @@ -6,6 +6,7 @@ mod mock; mod quic; mod stream_buffer; +use ark_ec::CurveGroup; pub use quic::*; use futures::{Sink, Stream}; @@ -16,7 +17,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::{ - algebra::{scalar::Scalar, stark_curve::StarkPoint}, + algebra::{curve::CurvePoint, scalar::Scalar}, error::MpcNetworkError, fabric::ResultId, }; @@ -30,54 +31,60 @@ pub type PartyId = u64; /// The type that the network sender receives #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct NetworkOutbound { +#[serde(bound = "C: CurveGroup")] +pub struct NetworkOutbound { /// The operation ID that generated this message pub result_id: ResultId, /// The body of the message - pub payload: NetworkPayload, + pub payload: NetworkPayload, } /// The payload of an outbound message #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum NetworkPayload { +#[serde(bound(serialize = "C: CurveGroup", deserialize = "C: CurveGroup"))] +pub enum NetworkPayload { /// A byte value Bytes(Vec), /// A scalar value - Scalar(Scalar), + Scalar(Scalar), /// A batch of scalar values - ScalarBatch(Vec), + ScalarBatch(Vec>), /// A point on the curve - Point(StarkPoint), + Point(CurvePoint), /// A batch of points on the curve - PointBatch(Vec), + PointBatch(Vec>), } -impl From> for NetworkPayload { +// --------------- +// | Conversions | +// --------------- + +impl From> for NetworkPayload { fn from(bytes: Vec) -> Self { Self::Bytes(bytes) } } -impl From for NetworkPayload { - fn from(scalar: Scalar) -> Self { +impl From> for NetworkPayload { + fn from(scalar: Scalar) -> Self { Self::Scalar(scalar) } } -impl From> for NetworkPayload { - fn from(scalars: Vec) -> Self { +impl From>> for NetworkPayload { + fn from(scalars: Vec>) -> Self { Self::ScalarBatch(scalars) } } -impl From for NetworkPayload { - fn from(point: StarkPoint) -> Self { +impl From> for NetworkPayload { + fn from(point: CurvePoint) -> Self { Self::Point(point) } } -impl From> for NetworkPayload { - fn from(value: Vec) -> Self { +impl From>> for NetworkPayload { + fn from(value: Vec>) -> Self { Self::PointBatch(value) } } @@ -88,17 +95,13 @@ impl From> for NetworkPayload { /// Values are sent as bytes, scalars, or curve points and always in batch form with the /// message length (measured in the number of elements sent) prepended to the message #[async_trait] -pub trait MpcNetwork: +pub trait MpcNetwork: Send - + Stream> - + Sink + + Stream, MpcNetworkError>> + + Sink, Error = MpcNetworkError> { /// Get the party ID of the local party in the MPC fn party_id(&self) -> PartyId; /// Closes the connections opened in the handshake phase async fn close(&mut self) -> Result<(), MpcNetworkError>; } - -// ----------- -// | Helpers | -// ----------- diff --git a/src/network/mock.rs b/src/network/mock.rs index 36ef940..fb11cf1 100644 --- a/src/network/mock.rs +++ b/src/network/mock.rs @@ -1,10 +1,12 @@ //! Defines a mock network for unit tests use std::{ + marker::PhantomData, pin::Pin, task::{Context, Poll}, }; +use ark_ec::CurveGroup; use async_trait::async_trait; use futures::{future::pending, Future, Sink, Stream}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; @@ -15,10 +17,10 @@ use super::{MpcNetwork, NetworkOutbound, PartyId}; /// A dummy MPC network that never receives messages #[derive(Default)] -pub struct NoRecvNetwork; +pub struct NoRecvNetwork(PhantomData); #[async_trait] -impl MpcNetwork for NoRecvNetwork { +impl MpcNetwork for NoRecvNetwork { fn party_id(&self) -> PartyId { PARTY0 } @@ -28,22 +30,22 @@ impl MpcNetwork for NoRecvNetwork { } } -impl Stream for NoRecvNetwork { - type Item = Result; +impl Stream for NoRecvNetwork { + type Item = Result, MpcNetworkError>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Box::pin(pending()).as_mut().poll(cx) } } -impl Sink for NoRecvNetwork { +impl Sink> for NoRecvNetwork { type Error = MpcNetworkError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, _item: NetworkOutbound) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, _item: NetworkOutbound) -> Result<(), Self::Error> { Ok(()) } @@ -57,14 +59,14 @@ impl Sink for NoRecvNetwork { } /// A dummy MPC network that operates over a duplex channel instead of a network connection/// An unbounded duplex channel used to mock a network connection -pub struct UnboundedDuplexStream { +pub struct UnboundedDuplexStream { /// The send side of the stream - send: UnboundedSender, + send: UnboundedSender>, /// The receive side of the stream - recv: UnboundedReceiver, + recv: UnboundedReceiver>, } -impl UnboundedDuplexStream { +impl UnboundedDuplexStream { /// Create a new pair of duplex streams pub fn new_duplex_pair() -> (Self, Self) { let (send1, recv1) = unbounded_channel(); @@ -83,27 +85,27 @@ impl UnboundedDuplexStream { } /// Send a message on the stream - pub fn send(&mut self, msg: NetworkOutbound) { + pub fn send(&mut self, msg: NetworkOutbound) { self.send.send(msg).unwrap(); } /// Recv a message from the stream - pub async fn recv(&mut self) -> NetworkOutbound { + pub async fn recv(&mut self) -> NetworkOutbound { self.recv.recv().await.unwrap() } } /// A dummy network implementation used for unit testing -pub struct MockNetwork { +pub struct MockNetwork { /// The ID of the local party party_id: PartyId, /// The underlying mock network connection - mock_conn: UnboundedDuplexStream, + mock_conn: UnboundedDuplexStream, } -impl MockNetwork { +impl MockNetwork { /// Create a new mock network from one half of a duplex stream - pub fn new(party_id: PartyId, stream: UnboundedDuplexStream) -> Self { + pub fn new(party_id: PartyId, stream: UnboundedDuplexStream) -> Self { Self { party_id, mock_conn: stream, @@ -112,7 +114,7 @@ impl MockNetwork { } #[async_trait] -impl MpcNetwork for MockNetwork { +impl MpcNetwork for MockNetwork { fn party_id(&self) -> PartyId { self.party_id } @@ -122,8 +124,8 @@ impl MpcNetwork for MockNetwork { } } -impl Stream for MockNetwork { - type Item = Result; +impl Stream for MockNetwork { + type Item = Result, MpcNetworkError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Box::pin(self.mock_conn.recv()) @@ -133,14 +135,14 @@ impl Stream for MockNetwork { } } -impl Sink for MockNetwork { +impl Sink> for MockNetwork { type Error = MpcNetworkError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn start_send(mut self: Pin<&mut Self>, item: NetworkOutbound) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: NetworkOutbound) -> Result<(), Self::Error> { self.mock_conn.send(item); Ok(()) } diff --git a/src/network/quic.rs b/src/network/quic.rs index 8e306bd..71f4742 100644 --- a/src/network/quic.rs +++ b/src/network/quic.rs @@ -1,9 +1,11 @@ //! Defines the central implementation of an `MpcNetwork` over the QUIC transport +use ark_ec::CurveGroup; use async_trait::async_trait; use futures::{Future, Sink, Stream}; use quinn::{Endpoint, RecvStream, SendStream}; use std::{ + marker::PhantomData, net::SocketAddr, pin::Pin, task::{Context, Poll}, @@ -36,7 +38,7 @@ const ERR_SEND_BUFFER_FULL: &str = "send buffer full"; // ----------------------- /// Implements an MpcNetwork on top of QUIC -pub struct QuicTwoPartyNet { +pub struct QuicTwoPartyNet { /// The index of the local party in the participants party_id: PartyId, /// Whether the network has been bootstrapped yet @@ -64,10 +66,12 @@ pub struct QuicTwoPartyNet { send_stream: Option, /// The receive side of the bidirectional stream recv_stream: Option, + /// The phantom on the curve group + _phantom: PhantomData, } #[allow(clippy::redundant_closure)] // For readability of error handling -impl<'a> QuicTwoPartyNet { +impl<'a, C: CurveGroup> QuicTwoPartyNet { /// Create a new network, do not connect the network yet pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self { // Construct the QUIC net @@ -81,12 +85,13 @@ impl<'a> QuicTwoPartyNet { buffered_outbound: None, send_stream: None, recv_stream: None, + _phantom: PhantomData, } } /// Returns true if the local party is party 0 fn local_party0(&self) -> bool { - self.party_id() == PARTY0 + self.party_id == PARTY0 } /// Returns an error if the network is not connected @@ -226,7 +231,7 @@ impl<'a> QuicTwoPartyNet { } /// Receive a message from the peer - async fn receive_message(&mut self) -> Result { + async fn receive_message(&mut self) -> Result, MpcNetworkError> { // Read the message length from the buffer if available if self.buffered_message_length.is_none() { self.buffered_message_length = Some(self.read_message_length().await?); @@ -246,7 +251,10 @@ impl<'a> QuicTwoPartyNet { } #[async_trait] -impl MpcNetwork for QuicTwoPartyNet { +impl MpcNetwork for QuicTwoPartyNet +where + C: Unpin, +{ fn party_id(&self) -> PartyId { self.party_id } @@ -263,18 +271,27 @@ impl MpcNetwork for QuicTwoPartyNet { } } -impl Stream for QuicTwoPartyNet { - type Item = Result; +impl Stream for QuicTwoPartyNet +where + C: Unpin, +{ + type Item = Result, MpcNetworkError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Box::pin(self.receive_message()).as_mut().poll(cx).map(Some) + Box::pin(self.get_mut().receive_message()) + .as_mut() + .poll(cx) + .map(Some) } } -impl Sink for QuicTwoPartyNet { +impl Sink> for QuicTwoPartyNet +where + C: Unpin, +{ type Error = MpcNetworkError; - fn start_send(mut self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> { if !self.connected { return Err(MpcNetworkError::NetworkUninitialized); } @@ -290,7 +307,7 @@ impl Sink for QuicTwoPartyNet { let mut payload = (bytes.len() as u64).to_le_bytes().to_vec(); payload.extend_from_slice(&bytes); - self.buffered_outbound = Some(BufferWithCursor::new(payload)); + self.get_mut().buffered_outbound = Some(BufferWithCursor::new(payload)); Ok(()) }