Skip to content

Commit

Permalink
network, fabric, algebra: Make entire crate generic over curve choice
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Oct 10, 2023
1 parent 6cd97f9 commit 470bbfa
Show file tree
Hide file tree
Showing 19 changed files with 519 additions and 471 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ path = "src/lib.rs"
[features]
benchmarks = []
stats = ["benchmarks"]
test_helpers = []
test_helpers = ["ark-curve25519"]

[[test]]
name = "integration"
Expand Down Expand Up @@ -73,6 +73,7 @@ kanal = "0.1.0-pre8"
tokio = { version = "1.12", features = ["macros", "rt-multi-thread"] }

# == Arithemtic + Crypto == #
ark-curve25519 = { version = "0.4", optional = true }
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = "0.4"
ark-serialize = "0.4"
Expand Down
22 changes: 13 additions & 9 deletions src/algebra/authenticated_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,13 @@ impl<C: CurveGroup> AuthenticatedPointResult<C> {
n: usize,
) -> Vec<AuthenticatedPointResult<C>> {
// Convert to a set of scalar results
let scalar_results = values
.fabric()
.new_batch_gate_op(vec![values.id()], n, |mut args| {
let args: Vec<CurvePoint<C>> = args.pop().unwrap().into();
args.into_iter().map(ResultValue::Point).collect_vec()
});
let scalar_results: Vec<CurvePointResult<C>> =
values
.fabric()
.new_batch_gate_op(vec![values.id()], n, |mut args| {
let args: Vec<CurvePoint<C>> = args.pop().unwrap().into();
args.into_iter().map(ResultValue::Point).collect_vec()
});

Self::new_shared_batch(&scalar_results)
}
Expand All @@ -137,7 +138,7 @@ impl<C: CurveGroup> AuthenticatedPointResult<C> {
}

/// Borrow the fabric that this result is allocated in
pub fn fabric(&self) -> &MpcFabric {
pub fn fabric(&self) -> &MpcFabric<C> {
self.share.fabric()
}

Expand Down Expand Up @@ -400,15 +401,18 @@ impl<C: CurveGroup> Debug for AuthenticatedPointOpenResult<C> {
}
}

impl<C: CurveGroup> Future for AuthenticatedPointOpenResult<C> {
impl<C: CurveGroup> Future for AuthenticatedPointOpenResult<C>
where
C::ScalarField: Unpin,
{
type Output = Result<CurvePoint<C>, MpcError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Await both of the underlying values
let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));

if mac_check == Scalar::from(1) {
if mac_check == Scalar::from(1u8) {
Poll::Ready(Ok(value))
} else {
Poll::Ready(Err(MpcError::AuthenticationError))
Expand Down
34 changes: 19 additions & 15 deletions src/algebra/authenticated_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
commitment::{PedersenCommitment, PedersenCommitmentResult},
error::MpcError,
fabric::{MpcFabric, ResultId, ResultValue},
ResultHandle, PARTY0,
PARTY0,
};

use super::{
Expand All @@ -34,7 +34,7 @@ pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;
/// SPDZ protocol: https://eprint.iacr.org/2011/535.pdf
/// that ensures security against a malicious adversary
#[derive(Clone)]
pub struct AuthenticatedScalarResult<C> {
pub struct AuthenticatedScalarResult<C: CurveGroup> {
/// The secret shares of the underlying value
pub(crate) share: MpcScalarResult<C>,
/// The SPDZ style, unconditionally secure MAC of the value
Expand Down Expand Up @@ -119,12 +119,13 @@ impl<C: CurveGroup> AuthenticatedScalarResult<C> {
n: usize,
) -> Vec<AuthenticatedScalarResult<C>> {
// Convert to a set of scalar results
let scalar_results = values
.fabric()
.new_batch_gate_op(vec![values.id()], n, |mut args| {
let scalars: Vec<Scalar<C>> = args.pop().unwrap().into();
scalars.into_iter().map(ResultValue::Scalar).collect()
});
let scalar_results: Vec<ScalarResult<C>> =
values
.fabric()
.new_batch_gate_op(vec![values.id()], n, |mut args| {
let scalars: Vec<Scalar<C>> = args.pop().unwrap().into();
scalars.into_iter().map(ResultValue::Scalar).collect()
});

Self::new_shared_batch(&scalar_results)
}
Expand All @@ -141,7 +142,7 @@ impl<C: CurveGroup> AuthenticatedScalarResult<C> {
}

/// Get a reference to the underlying MPC fabric
pub fn fabric(&self) -> &MpcFabric {
pub fn fabric(&self) -> &MpcFabric<C> {
self.share.fabric()
}

Expand Down Expand Up @@ -198,7 +199,7 @@ impl<C: CurveGroup> AuthenticatedScalarResult<C> {
}

// Sum of the commitments should be zero
if peer_mac_share + my_mac_share != Scalar::from(0) {
if peer_mac_share + my_mac_share != Scalar::zero() {
return false;
}

Expand Down Expand Up @@ -396,15 +397,18 @@ pub struct AuthenticatedScalarOpenResult<C: CurveGroup> {
pub mac_check: ScalarResult<C>,
}

impl<C: CurveGroup> Future for AuthenticatedScalarOpenResult<C> {
impl<C: CurveGroup> Future for AuthenticatedScalarOpenResult<C>
where
C::ScalarField: Unpin,
{
type Output = Result<Scalar<C>, MpcError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Await both of the underlying values
let value = futures::ready!(self.as_mut().value.poll_unpin(cx));
let mac_check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));

if mac_check == Scalar::from(1) {
if mac_check == Scalar::from(1u8) {
Poll::Ready(Ok(value))
} else {
Poll::Ready(Err(MpcError::AuthenticationError))
Expand All @@ -425,7 +429,7 @@ impl<C: CurveGroup> Add<&Scalar<C>> for &AuthenticatedScalarResult<C> {
let new_share = if self.fabric().party_id() == PARTY0 {
&self.share + rhs
} else {
&self.share + Scalar::from(0)
&self.share + Scalar::zero()
};

// Both parties add the public value to their modifier, and the MACs do not change
Expand All @@ -452,7 +456,7 @@ impl<C: CurveGroup> Add<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
let new_share = if self.fabric().party_id() == PARTY0 {
&self.share + rhs
} else {
&self.share + Scalar::from(0)
&self.share + Scalar::zero()
};

let new_modifier = &self.public_modifier - rhs;
Expand Down Expand Up @@ -1144,6 +1148,6 @@ mod tests {
})
.await;

assert_eq!(res.unwrap(), 0.into());
assert_eq!(res.unwrap(), 0u8.into());
}
}
98 changes: 33 additions & 65 deletions src/algebra/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ use ark_ec::{
map_to_curve_hasher::MapToCurve,
HashToCurveError,
},
short_weierstrass::{Affine, Projective, SWCurveConfig},
CurveConfig, CurveGroup, Group, VariableBaseMSM,
short_weierstrass::Projective,
CurveGroup,
};
use ark_ff::{BigInt, MontFp, PrimeField, Zero};
use ark_ff::PrimeField;

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError};
use ark_serialize::SerializationError;
use itertools::Itertools;
use num_bigint::BigUint;
use serde::{de::Error as DeError, Deserialize, Serialize};

use crate::{
Expand Down Expand Up @@ -54,6 +53,7 @@ pub const HASH_TO_CURVE_SECURITY: usize = 16; // 128 bit security
/// A wrapper around the inner point that allows us to define foreign traits on the point
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CurvePoint<C: CurveGroup>(pub(crate) C);
impl<C: CurveGroup> Unpin for CurvePoint<C> {}

impl<C: CurveGroup> Serialize for CurvePoint<C> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
Expand Down Expand Up @@ -96,22 +96,6 @@ impl<C: CurveGroup> CurvePoint<C> {
self.0.into_affine()
}

/// Construct a `CurvePoint<C>` from its affine coordinates
pub fn from_affine_coords(x: BigUint, y: BigUint) -> Self {
let x_bigint = BigInt::try_from(x).unwrap();
let y_bigint = BigInt::try_from(y).unwrap();
let x = Self::BaseField::from(x_bigint);
let y = Self::BaseField::from(y_bigint);

let aff = Affine {
x,
y,
infinity: false,
};

Self(aff.into())
}

/// The group generator
pub fn generator() -> CurvePoint<C> {
CurvePoint(C::generator())
Expand All @@ -132,21 +116,34 @@ impl<C: CurveGroup> CurvePoint<C> {
let point = C::deserialize_compressed(bytes)?;
Ok(CurvePoint(point))
}
}

impl<C: CurveGroup> CurvePoint<C>
where
C::BaseField: PrimeField,
{
/// Get the number of bytes needed to represent a point, this is exactly the number of bytes
/// for one base field element, as we can simply use the x-coordinate and set a high bit for the `y`
pub fn n_bytes() -> usize {
n_bytes_field::<Self::BaseField>()
n_bytes_field::<C::BaseField>()
}
}

impl<C: CurveGroup> CurvePoint<C>
where
C::Config: SWUConfig,
C::BaseField: PrimeField,
{
/// Convert a uniform byte buffer to a `CurvePoint<C>` via the SWU map-to-curve approach:
///
/// See https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-09#simple-swu
/// for a description of the setup. Essentially, we assume that the buffer provided is the
/// result of an `extend_message` implementation that gives us its uniform digest. From here
/// we construct two field elements, map to curve, and add the points to give a uniformly
/// distributed curve point
pub fn from_uniform_bytes(buf: Vec<u8>) -> Result<CurvePoint<C>, HashToCurveError> {
pub fn from_uniform_bytes(
buf: Vec<u8>,
) -> Result<CurvePoint<Projective<C::Config>>, HashToCurveError> {
let n_bytes = Self::n_bytes();
assert_eq!(
buf.len(),
Expand All @@ -169,7 +166,7 @@ impl<C: CurveGroup> CurvePoint<C> {
}

/// A helper that converts an arbitrarily long byte buffer to a field element
fn hash_to_field(buf: &[u8]) -> Self::BaseField {
fn hash_to_field(buf: &[u8]) -> C::BaseField {
Self::BaseField::from_be_bytes_mod_order(buf)
}
}
Expand All @@ -194,7 +191,6 @@ impl<C: CurveGroup> Add<&C> for &CurvePoint<C> {
}
}
impl_borrow_variants!(CurvePoint<C>, Add, add, +, C, C: CurveGroup);
impl_commutative!(CurvePoint<C>, Add, add, +, C, C: CurveGroup);

impl<C: CurveGroup> Add<&CurvePoint<C>> for &CurvePoint<C> {
type Output = CurvePoint<C>;
Expand Down Expand Up @@ -432,7 +428,7 @@ impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePoint<C> {
fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
let self_owned = *self;
rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
let rhs: Scalar = args[0].to_owned().into();
let rhs: Scalar<C> = args[0].to_owned().into();
ResultValue::Point(CurvePoint(self_owned.0 * rhs.0))
})
}
Expand All @@ -446,7 +442,7 @@ impl<C: CurveGroup> Mul<&ScalarResult<C>> for &CurvePointResult<C> {
fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
self.fabric.new_gate_op(vec![self.id, rhs.id], |mut args| {
let lhs: CurvePoint<C> = args.remove(0).into();
let rhs: Scalar = args.remove(0).into();
let rhs: Scalar<C> = args.remove(0).into();

ResultValue::Point(CurvePoint(lhs.0 * rhs.0))
})
Expand Down Expand Up @@ -620,7 +616,7 @@ impl<C: CurveGroup> CurvePoint<C> {
/// represented as streaming iterators
pub fn msm_iter<I, J>(scalars: I, points: J) -> CurvePoint<C>
where
I: IntoIterator<Item = Scalar>,
I: IntoIterator<Item = Scalar<C>>,
J: IntoIterator<Item = CurvePoint<C>>,
{
let mut res = CurvePoint::identity();
Expand All @@ -630,7 +626,7 @@ impl<C: CurveGroup> CurvePoint<C> {
.into_iter()
.zip(points.into_iter().chunks(MSM_CHUNK_SIZE).into_iter())
{
let scalars: Vec<Scalar> = scalar_chunk.collect();
let scalars: Vec<Scalar<C>> = scalar_chunk.collect();
let points: Vec<CurvePoint<C>> = point_chunk.collect();
let chunk_res = CurvePoint::msm(&scalars, &points);

Expand Down Expand Up @@ -667,7 +663,7 @@ impl<C: CurveGroup> CurvePoint<C> {
/// as iterators. Assumes the iterators are non-empty
pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
where
I: IntoIterator<Item = ScalarResult>,
I: IntoIterator<Item = ScalarResult<C>>,
J: IntoIterator<Item = CurvePoint<C>>,
{
Self::msm_results(
Expand Down Expand Up @@ -731,10 +727,10 @@ impl<C: CurveGroup> CurvePoint<C> {
/// This method assumes that the iterators are of the same length
pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
where
I: IntoIterator<Item = AuthenticatedScalarResult>,
I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
J: IntoIterator<Item = CurvePoint<C>>,
{
let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
let points: Vec<CurvePoint<C>> = points.into_iter().collect();

Self::msm_authenticated(&scalars, &points)
Expand Down Expand Up @@ -777,7 +773,7 @@ impl<C: CurveGroup> CurvePointResult<C> {
/// Assumes the iterator is non-empty
pub fn msm_results_iter<I, J>(scalars: I, points: J) -> CurvePointResult<C>
where
I: IntoIterator<Item = ScalarResult>,
I: IntoIterator<Item = ScalarResult<C>>,
J: IntoIterator<Item = CurvePointResult<C>>,
{
Self::msm_results(
Expand Down Expand Up @@ -848,10 +844,10 @@ impl<C: CurveGroup> CurvePointResult<C> {
/// represented as streaming iterators
pub fn msm_authenticated_iter<I, J>(scalars: I, points: J) -> AuthenticatedPointResult<C>
where
I: IntoIterator<Item = AuthenticatedScalarResult>,
I: IntoIterator<Item = AuthenticatedScalarResult<C>>,
J: IntoIterator<Item = CurvePointResult<C>>,
{
let scalars: Vec<AuthenticatedScalarResult> = scalars.into_iter().collect();
let scalars: Vec<AuthenticatedScalarResult<C>> = scalars.into_iter().collect();
let points: Vec<CurvePointResult<C>> = points.into_iter().collect();

Self::msm_authenticated(&scalars, &points)
Expand All @@ -866,11 +862,10 @@ impl<C: CurveGroup> CurvePointResult<C> {
/// https://github.com/xJonathanLEI/starknet-rs
#[cfg(test)]
mod test {
use rand::{thread_rng, RngCore};

use crate::algebra::test_helper::{random_point, TestCurve};
use crate::algebra::test_helper::random_point;

use super::*;

/// Test that the generators are the same between the two curve representations
#[test]
fn test_generators() {
Expand Down Expand Up @@ -922,31 +917,4 @@ mod test {

assert_eq!(p1, res);
}

/// Tests the hash-to-curve implementation `CurvePoint<C>::from_uniform_bytes`
#[test]
fn test_hash_to_curve() {
// Sample random bytes into a buffer
let mut rng = thread_rng();
const N_BYTES: usize = n_bytes_field::<TestCurve::BaseField>();
let mut buf = [0u8; N_BYTES * 2];
rng.fill_bytes(&mut buf);

// As long as the method does not error, the test is successful
let res = CurvePoint::from_uniform_bytes(buf);
assert!(res.is_ok())
}

/// Tests converting to and from affine coordinates
#[test]
fn test_affine_conversion() {
let projective = random_point();
let affine = projective.to_affine();

let x = BigUint::from(affine.x);
let y = BigUint::from(affine.y);
let recovered = CurvePoint::from_affine_coords(x, y);

assert_eq!(projective, recovered);
}
}
Loading

0 comments on commit 470bbfa

Please sign in to comment.