diff --git a/Cargo.toml b/Cargo.toml index 8431b33..d7c5430 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ ark-curve25519 = { version = "0.4", optional = true } ark-ec = { version = "0.4", features = ["parallel"] } ark-ff = "0.4" ark-serialize = "0.4" +ark-std = "0.4" digest = "0.10" num-bigint = "0.4" rand = "0.8" diff --git a/integration/authenticated_stark_point.rs b/integration/authenticated_curve.rs similarity index 91% rename from integration/authenticated_stark_point.rs rename to integration/authenticated_curve.rs index 37536f9..fbae544 100644 --- a/integration/authenticated_stark_point.rs +++ b/integration/authenticated_curve.rs @@ -3,9 +3,9 @@ use itertools::Itertools; use mpc_stark::{ algebra::{ - authenticated_stark_point::{ + authenticated_curve::{ test_helpers::{modify_mac, modify_public_modifier, modify_share}, - AuthenticatedStarkPointResult, + AuthenticatedPointResult, }, scalar::Scalar, }, @@ -144,10 +144,9 @@ fn test_batch_add(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args); let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args); - let res = AuthenticatedStarkPointResult::batch_add(&party0_values, &party1_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_add(&party0_values, &party1_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } @@ -172,10 +171,9 @@ fn test_batch_add_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Add the points in the MPC circuit let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args); - let res = AuthenticatedStarkPointResult::batch_add_public(&party0_values, &plaintext_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_add_public(&party0_values, &plaintext_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } @@ -240,10 +238,9 @@ fn test_batch_sub(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args); let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args); - let res = AuthenticatedStarkPointResult::batch_sub(&party0_values, &party1_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_sub(&party0_values, &party1_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } @@ -268,10 +265,9 @@ fn test_batch_sub_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Add the points in the MPC circuit let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args); - let res = AuthenticatedStarkPointResult::batch_sub_public(&party0_values, &plaintext_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_sub_public(&party0_values, &plaintext_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } @@ -310,10 +306,9 @@ fn test_batch_negation(test_args: &IntegrationTestArgs) -> Result<(), String> { // Compute the expected result in an MPC circuit let party0_values = share_authenticated_point_batch(my_values, PARTY0, test_args); - let res = AuthenticatedStarkPointResult::batch_neg(&party0_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_neg(&party0_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(expected_res, res_open) } @@ -382,10 +377,9 @@ fn test_batch_mul(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args); let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args); - let res = AuthenticatedStarkPointResult::batch_sub(&party0_values, &party1_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_sub(&party0_values, &party1_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } @@ -410,10 +404,9 @@ fn test_batch_mul_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Add the points in the MPC circuit let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args); - let res = AuthenticatedStarkPointResult::batch_sub_public(&party0_values, &plaintext_values); - let res_open = await_batch_result_with_error( - AuthenticatedStarkPointResult::open_authenticated_batch(&res), - )?; + let res = AuthenticatedPointResult::batch_sub_public(&party0_values, &plaintext_values); + let res_open = + await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?; assert_point_batches_eq(res_open, expected_result) } diff --git a/integration/authenticated_scalar.rs b/integration/authenticated_scalar.rs index 2e98dba..68c2101 100644 --- a/integration/authenticated_scalar.rs +++ b/integration/authenticated_scalar.rs @@ -18,7 +18,7 @@ use crate::{ await_result, await_result_batch, await_result_with_error, share_authenticated_scalar, share_authenticated_scalar_batch, share_plaintext_value, share_plaintext_values_batch, }, - IntegrationTest, IntegrationTestArgs, + IntegrationTest, IntegrationTestArgs, TestScalar, }; // ----------- @@ -108,7 +108,7 @@ fn test_add_public_value(test_args: &IntegrationTestArgs) -> Result<(), String> let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric); let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric); - let plaintext_constant: Scalar = await_result(party1_value); + let plaintext_constant: TestScalar = await_result(party1_value); let expected_result = await_result(party0_value) + plaintext_constant; // Compute the result in the MPC circuit @@ -213,7 +213,7 @@ fn test_sub_public_scalar(test_args: &IntegrationTestArgs) -> Result<(), String> let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric); let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric); - let plaintext_constant: Scalar = await_result(party1_value); + let plaintext_constant: TestScalar = await_result(party1_value); let expected_result = await_result(party0_value) - plaintext_constant; // Compute the result in the MPC circuit @@ -365,7 +365,7 @@ fn test_mul_public_scalar(test_args: &IntegrationTestArgs) -> Result<(), String> let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric); let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric); - let plaintext_constant: Scalar = await_result(party1_value); + let plaintext_constant: TestScalar = await_result(party1_value); let expected_result = await_result(party0_value) * plaintext_constant; // Compute the result in the MPC circuit diff --git a/integration/circuits.rs b/integration/circuits.rs index 9cf1cb2..5cfe1a7 100644 --- a/integration/circuits.rs +++ b/integration/circuits.rs @@ -3,9 +3,8 @@ use itertools::Itertools; use mpc_stark::{ algebra::{ - authenticated_scalar::AuthenticatedScalarResult, - authenticated_stark_point::AuthenticatedStarkPointResult, scalar::Scalar, - stark_curve::StarkPoint, + authenticated_curve::AuthenticatedPointResult, + authenticated_scalar::AuthenticatedScalarResult, scalar::Scalar, }, random_point, PARTY0, PARTY1, }; @@ -16,7 +15,7 @@ use crate::{ assert_points_eq, assert_scalars_eq, await_result, await_result_batch, share_plaintext_value, share_plaintext_values_batch, }, - IntegrationTest, IntegrationTestArgs, + IntegrationTest, IntegrationTestArgs, TestCurve, TestCurvePoint, TestScalar, }; /// Tests an inner product between two vectors of shared scalars @@ -40,7 +39,7 @@ fn test_inner_product(test_args: &IntegrationTestArgs) -> Result<(), String> { let b_plaintext = await_result_batch(&share_plaintext_values_batch(&allocd_vals, PARTY1, fabric)); - let expected_res: Scalar = a_plaintext + let expected_res: TestScalar = a_plaintext .iter() .zip(b_plaintext) .map(|(a, b)| a * b) @@ -57,7 +56,8 @@ fn test_inner_product(test_args: &IntegrationTestArgs) -> Result<(), String> { .collect_vec(); // Compute the inner product - let res: AuthenticatedScalarResult = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum(); + let res: AuthenticatedScalarResult = + a.iter().zip(b.iter()).map(|(a, b)| a * b).sum(); let res_open = await_result(res.open_authenticated()) .map_err(|err| format!("error opening result: {err:?}"))?; @@ -96,7 +96,7 @@ fn test_msm(test_args: &IntegrationTestArgs) -> Result<(), String> { fabric, )); - let expected_res = StarkPoint::msm(&plaintext_scalars, &plaintext_points); + let expected_res = TestCurvePoint::msm(&plaintext_scalars, &plaintext_points); // Share the values in an MPC circuit let shared_scalars = my_scalars @@ -109,7 +109,7 @@ fn test_msm(test_args: &IntegrationTestArgs) -> Result<(), String> { .collect_vec(); // Compare results - let res = AuthenticatedStarkPointResult::msm(&shared_scalars, &shared_points); + let res = AuthenticatedPointResult::msm(&shared_scalars, &shared_points); let res_open = await_result(res.open_authenticated()) .map_err(|err| format!("error opening msm result: {err:?}"))?; diff --git a/integration/fabric.rs b/integration/fabric.rs index 9158d13..8c58b71 100644 --- a/integration/fabric.rs +++ b/integration/fabric.rs @@ -21,14 +21,14 @@ fn test_fabric_share_and_open(test_args: &IntegrationTestArgs) -> Result<(), Str let party0_value_opened = party0_value.open(); let party0_res = await_result(party0_value_opened); - assert_scalars_eq(party0_res, Scalar::from(0))?; + assert_scalars_eq(party0_res, Scalar::from(0u8))?; // Party 1 let party1_value = share_scalar(my_party_id, PARTY1, test_args); let party1_value_opened = party1_value.open(); let party1_res = await_result(party1_value_opened); - assert_scalars_eq(party1_res, Scalar::from(1)) + assert_scalars_eq(party1_res, Scalar::from(1u8)) } inventory::submit!(IntegrationTest { diff --git a/integration/helpers.rs b/integration/helpers.rs index 79b0c21..205c20c 100644 --- a/integration/helpers.rs +++ b/integration/helpers.rs @@ -6,9 +6,9 @@ use futures::{future::join_all, Future}; use itertools::Itertools; use mpc_stark::{ algebra::{ - authenticated_scalar::AuthenticatedScalarResult, - authenticated_stark_point::AuthenticatedStarkPointResult, mpc_scalar::MpcScalarResult, - mpc_stark_point::MpcStarkPointResult, scalar::Scalar, stark_curve::StarkPoint, + authenticated_curve::AuthenticatedPointResult, + authenticated_scalar::AuthenticatedScalarResult, mpc_curve::MpcPointResult, + mpc_scalar::MpcScalarResult, scalar::Scalar, }, beaver::SharedValueSource, network::{NetworkPayload, PartyId}, @@ -20,11 +20,11 @@ use tokio::runtime::Handle; // | Helpers | // ----------- -use crate::IntegrationTestArgs; +use crate::{IntegrationTestArgs, TestCurve, TestCurvePoint, TestScalar}; /// Compares two scalars, returning a result that can be propagated up an integration test /// stack in the case that the scalars are not equal -pub(crate) fn assert_scalars_eq(a: Scalar, b: Scalar) -> Result<(), String> { +pub(crate) fn assert_scalars_eq(a: TestScalar, b: TestScalar) -> Result<(), String> { if a == b { Ok(()) } else { @@ -33,7 +33,10 @@ pub(crate) fn assert_scalars_eq(a: Scalar, b: Scalar) -> Result<(), String> { } /// Assert a batch of scalars equal one another -pub(crate) fn assert_scalar_batches_eq(a: Vec, b: Vec) -> Result<(), String> { +pub(crate) fn assert_scalar_batches_eq( + a: Vec, + b: Vec, +) -> Result<(), String> { if a.len() != b.len() { return Err(format!("Lengths differ: {a:?} != {b:?}")); } @@ -47,7 +50,7 @@ pub(crate) fn assert_scalar_batches_eq(a: Vec, b: Vec) -> Result /// Compares two points, returning a result that can be propagated up an integration test /// stack in the case that the points are not equal -pub(crate) fn assert_points_eq(a: StarkPoint, b: StarkPoint) -> Result<(), String> { +pub(crate) fn assert_points_eq(a: TestCurvePoint, b: TestCurvePoint) -> Result<(), String> { if a == b { Ok(()) } else { @@ -57,8 +60,8 @@ pub(crate) fn assert_points_eq(a: StarkPoint, b: StarkPoint) -> Result<(), Strin /// Compares two batches of points pub(crate) fn assert_point_batches_eq( - a: Vec, - b: Vec, + a: Vec, + b: Vec, ) -> Result<(), String> { if a.len() != b.len() { return Err(format!("Lengths differ: {a:?} != {b:?}")); @@ -116,20 +119,20 @@ where /// Send or receive a secret shared scalar from the given party pub(crate) fn share_scalar( - value: Scalar, + value: TestScalar, sender: PartyId, test_args: &IntegrationTestArgs, -) -> MpcScalarResult { +) -> MpcScalarResult { let authenticated_value = test_args.fabric.share_scalar(value, sender); authenticated_value.mpc_share() } /// Share a batch of scalars pub(crate) fn share_scalar_batch( - values: Vec, + values: Vec, sender: PartyId, test_args: &IntegrationTestArgs, -) -> Vec { +) -> Vec> { test_args .fabric .batch_share_scalar(values, sender) @@ -140,10 +143,10 @@ pub(crate) fn share_scalar_batch( /// Send or receive a secret shared point from the given party pub(crate) fn share_point( - value: StarkPoint, + value: TestCurvePoint, sender: PartyId, test_args: &IntegrationTestArgs, -) -> MpcStarkPointResult { +) -> MpcPointResult { // Share the point then cast to an `MpcStarkPoint` let authenticated_point = share_authenticated_point(value, sender, test_args); authenticated_point.mpc_share() @@ -151,10 +154,10 @@ pub(crate) fn share_point( /// Share a batch of points pub(crate) fn share_point_batch( - values: Vec, + values: Vec, sender: PartyId, test_args: &IntegrationTestArgs, -) -> Vec { +) -> Vec> { values .into_iter() .map(|point| share_point(point, sender, test_args)) @@ -163,46 +166,48 @@ pub(crate) fn share_point_batch( /// Send or receive a secret shared scalar from the given party and allocate it as an authenticated value pub(crate) fn share_authenticated_scalar( - value: Scalar, + value: TestScalar, sender: PartyId, test_args: &IntegrationTestArgs, -) -> AuthenticatedScalarResult { +) -> AuthenticatedScalarResult { test_args.fabric.share_scalar(value, sender) } /// Send or receive a batch of secret shared scalars from the given party and allocate them as authenticated values pub(crate) fn share_authenticated_scalar_batch( - values: Vec, + values: Vec, sender: PartyId, test_args: &IntegrationTestArgs, -) -> Vec { +) -> Vec> { test_args.fabric.batch_share_scalar(values, sender) } /// Send or receive a secret shared point from the given party and allocate it as an authenticated value pub(crate) fn share_authenticated_point( - value: StarkPoint, + value: TestCurvePoint, sender: PartyId, test_args: &IntegrationTestArgs, -) -> AuthenticatedStarkPointResult { +) -> AuthenticatedPointResult { test_args.fabric.share_point(value, sender) } /// Send or receive a batch of secret shared points from the given party and allocate them as authenticated values pub(crate) fn share_authenticated_point_batch( - values: Vec, + values: Vec, sender: PartyId, test_args: &IntegrationTestArgs, -) -> Vec { +) -> Vec> { test_args.fabric.batch_share_point(values, sender) } /// Share a value with the counterparty by sender ID, the sender sends and the receiver receives -pub(crate) fn share_plaintext_value + Into>( - value: ResultHandle, +pub(crate) fn share_plaintext_value< + T: From> + Into>, +>( + value: ResultHandle, sender: PartyId, - fabric: &MpcFabric, -) -> ResultHandle { + fabric: &MpcFabric, +) -> ResultHandle { if fabric.party_id() == sender { fabric.send_value(value) } else { @@ -211,11 +216,13 @@ pub(crate) fn share_plaintext_value + Into> } /// Share a batch of values in the plaintext -pub(crate) fn share_plaintext_values_batch + Into + Clone>( - values: &[ResultHandle], +pub(crate) fn share_plaintext_values_batch< + T: From> + Into> + Clone, +>( + values: &[ResultHandle], sender: PartyId, - fabric: &MpcFabric, -) -> Vec> { + fabric: &MpcFabric, +) -> Vec> { values .iter() .map(|v| share_plaintext_value(v.clone(), sender, fabric)) @@ -240,14 +247,14 @@ impl PartyIDBeaverSource { /// The PartyIDBeaverSource returns beaver triplets split statically between the /// parties. We assume a = 2, b = 3 ==> c = 6. [a] = (1, 1); [b] = (3, 0) [c] = (2, 4) -impl SharedValueSource for PartyIDBeaverSource { - fn next_shared_bit(&mut self) -> Scalar { +impl SharedValueSource for PartyIDBeaverSource { + fn next_shared_bit(&mut self) -> TestScalar { // Simply output partyID, assume partyID \in {0, 1} assert!(self.party_id == 0 || self.party_id == 1); Scalar::from(self.party_id) } - fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) { + fn next_triplet(&mut self) -> (TestScalar, TestScalar, TestScalar) { if self.party_id == 0 { (Scalar::from(1u64), Scalar::from(3u64), Scalar::from(2u64)) } else { @@ -255,11 +262,11 @@ impl SharedValueSource for PartyIDBeaverSource { } } - fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar) { - (Scalar::from(1), Scalar::from(1)) + fn next_shared_inverse_pair(&mut self) -> (TestScalar, TestScalar) { + (Scalar::from(1u8), Scalar::from(1u8)) } - fn next_shared_value(&mut self) -> Scalar { + fn next_shared_value(&mut self) -> TestScalar { Scalar::from(self.party_id) } } diff --git a/integration/main.rs b/integration/main.rs index 4c5ff36..2c2a919 100644 --- a/integration/main.rs +++ b/integration/main.rs @@ -1,5 +1,7 @@ use std::{borrow::Borrow, io::Write, net::SocketAddr, process::exit, thread, time::Duration}; +use ark_curve25519::Curve25519Config; +use ark_ec::twisted_edwards::Projective; use clap::Parser; use colored::Colorize; use dns_lookup::lookup_host; @@ -7,28 +9,36 @@ use env_logger::Builder; use futures::{SinkExt, StreamExt}; use helpers::PartyIDBeaverSource; use mpc_stark::{ + algebra::{curve::CurvePoint, scalar::Scalar}, network::{NetworkOutbound, NetworkPayload, QuicTwoPartyNet}, MpcFabric, PARTY0, }; use tokio::runtime::{Builder as RuntimeBuilder, Handle}; use tracing::log::{self, LevelFilter}; +mod authenticated_curve; mod authenticated_scalar; -mod authenticated_stark_point; mod circuits; mod fabric; mod helpers; +mod mpc_curve; mod mpc_scalar; -mod mpc_stark_point; /// The amount of time to sleep after sending a shutdown const SHUTDOWN_TIMEOUT_MS: u64 = 3_000; // 3 seconds +/// The curve used for testing, set to curve25519 +pub type TestCurve = Projective; +/// The curve point type used for testing +pub type TestCurvePoint = CurvePoint; +/// The scalar point ype used for testing +pub type TestScalar = Scalar; + /// Integration test arguments, common to all tests #[derive(Clone, Debug)] struct IntegrationTestArgs { party_id: u64, - fabric: MpcFabric, + fabric: MpcFabric, } /// Integration test format diff --git a/integration/mpc_stark_point.rs b/integration/mpc_curve.rs similarity index 91% rename from integration/mpc_stark_point.rs rename to integration/mpc_curve.rs index a21cbf8..4ae0766 100644 --- a/integration/mpc_stark_point.rs +++ b/integration/mpc_curve.rs @@ -3,9 +3,9 @@ use itertools::Itertools; use mpc_stark::{ algebra::{ - mpc_stark_point::MpcStarkPointResult, + curve::CurvePointResult, + mpc_curve::MpcPointResult, scalar::{Scalar, ScalarResult}, - stark_curve::StarkPointResult, }, random_point, PARTY0, PARTY1, }; @@ -17,7 +17,7 @@ use crate::{ share_plaintext_value, share_plaintext_values_batch, share_point, share_point_batch, share_scalar, share_scalar_batch, }, - IntegrationTest, IntegrationTestArgs, + IntegrationTest, IntegrationTestArgs, TestCurve, }; /// Test addition of `MpcStarkPoint` types @@ -89,13 +89,13 @@ fn test_batch_add(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_point_batch(points.clone(), PARTY0, test_args); let party1_values = share_point_batch(points, PARTY1, test_args); - let res = MpcStarkPointResult::batch_add(&party0_values, &party1_values); - let opened_res = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_add(&party0_values, &party1_values); + let opened_res = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(opened_res, expected_result) } -/// Tests addition between a batch of `MpcStarkPointResult`s and `StarkPointResult`s +/// Tests addition between a batch of `MpcPointResult`s and `StarkPointResult`s fn test_batch_add_public(test_args: &IntegrationTestArgs) -> Result<(), String> { let n = 10; let fabric = &test_args.fabric; @@ -114,8 +114,8 @@ fn test_batch_add_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Secret share the values and add them together in the MPC circuit let party0_values = share_point_batch(points, PARTY0, test_args); - let res = MpcStarkPointResult::batch_add_public(&party0_values, &plaintext_values); - let res_open = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_add_public(&party0_values, &plaintext_values); + let res_open = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(res_open, expected_result) } @@ -189,8 +189,8 @@ fn test_batch_sub(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_point_batch(points.clone(), PARTY0, test_args); let party1_values = share_point_batch(points, PARTY1, test_args); - let res = MpcStarkPointResult::batch_sub(&party0_values, &party1_values); - let opened_res = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_sub(&party0_values, &party1_values); + let opened_res = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(opened_res, expected_result) } @@ -214,8 +214,8 @@ fn test_batch_sub_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Secret share the values and add them together in the MPC circuit let party0_values = share_point_batch(points, PARTY0, test_args); - let res = MpcStarkPointResult::batch_sub_public(&party0_values, &plaintext_values); - let res_open = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_sub_public(&party0_values, &plaintext_values); + let res_open = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(res_open, expected_result) } @@ -258,8 +258,8 @@ fn test_batch_neg(test_args: &IntegrationTestArgs) -> Result<(), String> { // Secret share the values and add them together in the MPC circuit let party0_values = share_point_batch(points, PARTY0, test_args); - let res = MpcStarkPointResult::batch_neg(&party0_values); - let opened_res = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_neg(&party0_values); + let opened_res = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(opened_res, expected_result) } @@ -273,12 +273,12 @@ fn test_mul(test_args: &IntegrationTestArgs) -> Result<(), String> { let scalar = Scalar::random(&mut rng); // Share the values with the counterparty - let plaintext_point: StarkPointResult = share_plaintext_value( + let plaintext_point: CurvePointResult = share_plaintext_value( test_args.fabric.allocate_point(point), PARTY0, &test_args.fabric, ); - let plaintext_scalar: ScalarResult = share_plaintext_value( + let plaintext_scalar: ScalarResult = share_plaintext_value( test_args.fabric.allocate_scalar(scalar), PARTY1, &test_args.fabric, @@ -311,7 +311,7 @@ fn test_mul_scalar_constant(test_args: &IntegrationTestArgs) -> Result<(), Strin PARTY0, &test_args.fabric, ); - let plaintext_scalar: ScalarResult = share_plaintext_value( + let plaintext_scalar: ScalarResult = share_plaintext_value( test_args.fabric.allocate_scalar(scalar), PARTY1, &test_args.fabric, @@ -361,13 +361,13 @@ fn test_batch_mul(test_args: &IntegrationTestArgs) -> Result<(), String> { let party0_values = share_point_batch(points, PARTY0, test_args); let party1_values = share_scalar_batch(scalars, PARTY1, test_args); - let res = MpcStarkPointResult::batch_mul(&party1_values, &party0_values); - let opened_res = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_mul(&party1_values, &party0_values); + let opened_res = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(opened_res, expected_result) } -/// Test multiplication of a batch of `MpcStarkPointResult`s with `ScalarResult`s +/// Test multiplication of a batch of `MpcPointResult`s with `ScalarResult`s fn test_batch_mul_public(test_args: &IntegrationTestArgs) -> Result<(), String> { let n = 10; let mut rng = thread_rng(); @@ -390,8 +390,8 @@ fn test_batch_mul_public(test_args: &IntegrationTestArgs) -> Result<(), String> // Secret share the values and add them together in the MPC circuit let party0_values = share_point_batch(points, PARTY0, test_args); - let res = MpcStarkPointResult::batch_mul_public(&plaintext_values, &party0_values); - let res_open = await_result_batch(&MpcStarkPointResult::open_batch(&res)); + let res = MpcPointResult::batch_mul_public(&plaintext_values, &party0_values); + let res_open = await_result_batch(&MpcPointResult::open_batch(&res)); assert_point_batches_eq(res_open, expected_result) } diff --git a/integration/mpc_scalar.rs b/integration/mpc_scalar.rs index bcbf1c7..fd54527 100644 --- a/integration/mpc_scalar.rs +++ b/integration/mpc_scalar.rs @@ -1,8 +1,11 @@ //! Defines unit tests for `MpcScalarResult` types use itertools::Itertools; use mpc_stark::{ - algebra::{mpc_scalar::MpcScalarResult, scalar::Scalar}, - ResultHandle, PARTY0, PARTY1, + algebra::{ + mpc_scalar::MpcScalarResult, + scalar::{Scalar, ScalarResult}, + }, + PARTY0, PARTY1, }; use rand::thread_rng; use std::ops::Neg; @@ -12,7 +15,7 @@ use crate::{ assert_scalar_batches_eq, assert_scalars_eq, await_result, await_result_batch, share_plaintext_value, share_plaintext_values_batch, share_scalar, share_scalar_batch, }, - IntegrationTest, IntegrationTestArgs, + IntegrationTest, IntegrationTestArgs, TestCurve, }; /// Test addition of `MpcScalarResult` types @@ -20,7 +23,7 @@ fn test_add(test_args: &IntegrationTestArgs) -> Result<(), String> { // Each party allocates a random value let mut rng = thread_rng(); let val = Scalar::random(&mut rng); - let my_value: ResultHandle = test_args.fabric.allocate_scalar(val); + let my_value: ScalarResult = test_args.fabric.allocate_scalar(val); // Share the value with the counterparty let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric); @@ -46,7 +49,7 @@ fn test_add_scalar_constant(test_args: &IntegrationTestArgs) -> Result<(), Strin // Each party allocates a random value let mut rng = thread_rng(); let val = Scalar::random(&mut rng); - let my_value: ResultHandle = test_args.fabric.allocate_scalar(val); + let my_value: ScalarResult = test_args.fabric.allocate_scalar(val); // Share the value with the counterparty let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric); diff --git a/src/algebra/curve.rs b/src/algebra/curve.rs index aa9a566..656a42a 100644 --- a/src/algebra/curve.rs +++ b/src/algebra/curve.rs @@ -862,59 +862,74 @@ impl CurvePointResult { /// https://github.com/xJonathanLEI/starknet-rs #[cfg(test)] mod test { - use crate::algebra::test_helper::random_point; + use rand::thread_rng; + + use crate::{algebra::test_helper::TestCurve, test_helpers::mock_fabric}; use super::*; - /// Test that the generators are the same between the two curve representations - #[test] - fn test_generators() { - // let generator_1 = CurvePoint::generator(); - // let generator_2 = ProjectivePoint::from_affine_point(&GENERATOR); + /// A curve point on the test curve + pub type TestCurvePoint = CurvePoint; - // assert!(compare_points(&generator_1, &generator_2)); + /// Generate a random point, by multiplying the basepoint with a random scalar + pub fn random_point() -> TestCurvePoint { + let mut rng = thread_rng(); + let scalar = Scalar::random(&mut rng); + let point = TestCurvePoint::generator() * scalar; + point * scalar } /// Tests point addition - #[test] - fn test_point_addition() { - // let p1 = random_point(); - // let q1 = random_point(); + #[tokio::test] + async fn test_point_addition() { + let fabric = mock_fabric(); - // let p2 = arkworks_point_to_starknet(&p1); - // let q2 = arkworks_point_to_starknet(&q1); + let p1 = random_point(); + let p2 = random_point(); - // let r1 = p1 + q1; + let p1_res = fabric.allocate_point(p1); + let p2_res = fabric.allocate_point(p2); - // // Only `AddAssign` is implemented on `ProjectivePoint` - // let mut r2 = p2; - // r2 += &q2; + let res = (p1_res + p2_res).await; + let expected_res = p1 + p2; - // assert!(compare_points(&r1, &r2)); + assert_eq!(res, expected_res); + fabric.shutdown(); } /// Tests scalar multiplication - #[test] - fn test_scalar_mul() { - // let mut rng = thread_rng(); - // let s1 = Scalar::random(&mut rng); - // let p1 = random_point(); + #[tokio::test] + async fn test_scalar_mul() { + let fabric = mock_fabric(); + + let mut rng = thread_rng(); + let s1 = Scalar::::random(&mut rng); + let p1 = random_point(); - // let s2 = prime_field_to_starknet_felt(&s1.0); - // let p2 = arkworks_point_to_starknet(&p1); + let s1_res = fabric.allocate_scalar(s1); + let p1_res = fabric.allocate_point(p1); - // let r1 = p1 * s1; - // let r2 = starknet_rs_scalar_mul(&s2, &p2); + let res = (s1_res * p1_res).await; + let expected_res = s1 * p1; - // assert!(compare_points(&r1, &r2)); + assert_eq!(res, expected_res); + fabric.shutdown(); } /// Tests addition with the additive identity - #[test] - fn test_additive_identity() { + #[tokio::test] + async fn test_additive_identity() { + let fabric = mock_fabric(); + let p1 = random_point(); - let res = p1 + CurvePoint::identity(); - assert_eq!(p1, res); + let p1_res = fabric.allocate_point(p1); + let identity_res = fabric.curve_identity(); + + let res = (p1_res + identity_res).await; + let expected_res = p1; + + assert_eq!(res, expected_res); + fabric.shutdown(); } } diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 9d41d3b..398c96c 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -11,32 +11,8 @@ pub mod scalar; /// Helpers useful for testing throughout the `algebra` module #[cfg(any(test, feature = "test_helpers"))] pub(crate) mod test_helper { - use super::{curve::CurvePoint, scalar::Scalar}; - use ark_curve25519::EdwardsProjective as Curve25519Projective; - use ark_ff::PrimeField; - use num_bigint::BigUint; - use rand::thread_rng; - - // ----------- - // | Helpers | - // ----------- /// A curve used for testing algebra implementations, set to curve25519 pub type TestCurve = Curve25519Projective; - /// A curve point on the test curve - pub type TestCurvePoint = CurvePoint; - - /// Generate a random point, by multiplying the basepoint with a random scalar - pub fn random_point() -> TestCurvePoint { - let mut rng = thread_rng(); - let scalar = Scalar::random(&mut rng); - let point = TestCurvePoint::generator() * scalar; - point * scalar - } - - /// Convert a prime field element to a `BigUint` - pub fn prime_field_to_biguint(val: &F) -> BigUint { - (*val).into() - } } diff --git a/src/algebra/scalar.rs b/src/algebra/scalar.rs index b0f0a2d..035f4da 100644 --- a/src/algebra/scalar.rs +++ b/src/algebra/scalar.rs @@ -12,6 +12,7 @@ use std::{ use ark_ec::CurveGroup; use ark_ff::{batch_inversion, Field, PrimeField}; +use ark_std::UniformRand; use itertools::Itertools; use num_bigint::BigUint; use rand::{CryptoRng, RngCore}; @@ -65,11 +66,7 @@ impl Scalar { /// /// TODO: Validate that this gives a uniform distribution over the field pub fn random(rng: &mut R) -> Self { - let mut random_bytes = vec![0u8; n_bytes_field::()]; - rng.fill_bytes(&mut random_bytes); - - let val = C::ScalarField::from_random_bytes(&random_bytes).unwrap(); - Self(val) + Self(C::ScalarField::rand(rng)) } /// Compute the multiplicative inverse of the scalar in its field diff --git a/src/fabric/executor.rs b/src/fabric/executor.rs index b660fb4..dbed345 100644 --- a/src/fabric/executor.rs +++ b/src/fabric/executor.rs @@ -5,6 +5,9 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; +#[cfg(feature = "stats")] +use std::fmt::{Formatter, Result as FmtResult}; + use ark_ec::CurveGroup; use crossbeam::queue::SegQueue; use itertools::Itertools; @@ -63,7 +66,7 @@ impl ExecutorStats { } /// Add an operation to the executor's depth map - pub fn new_operation(&mut self, op: &Operation, from_network_op: bool) { + pub fn new_operation(&mut self, op: &Operation, from_network_op: bool) { let max_dep = op .args .iter() @@ -274,7 +277,7 @@ impl Executor { /// Record the depth of an operation in the circuit #[cfg(feature = "stats")] - fn record_op_depth(&mut self, op: &Operation) { + fn record_op_depth(&mut self, op: &Operation) { let is_network_op = matches!(op.op_type, OperationType::Network { .. }); self.stats.new_operation(op, is_network_op); } diff --git a/src/fabric/network_sender.rs b/src/fabric/network_sender.rs index d2d9311..8d5fdfb 100644 --- a/src/fabric/network_sender.rs +++ b/src/fabric/network_sender.rs @@ -40,6 +40,7 @@ pub struct NetworkStats { pub messages_received: AtomicUsize, } +#[allow(unused)] impl NetworkStats { /// Increment the number of bytes sent pub fn increment_bytes_sent(&self, bytes: usize) { @@ -140,7 +141,7 @@ impl + 'static> NetworkSender { async fn read_loop( mut network_stream: SplitStream, result_queue: Arc>>, - stats: Arc, + #[allow(unused)] stats: Arc, ) -> MpcNetworkError { while let Some(Ok(msg)) = network_stream.next().await { #[cfg(feature = "stats")] @@ -164,7 +165,7 @@ impl + 'static> NetworkSender { async fn write_loop( outbound_stream: KanalReceiver>, mut network: SplitSink>, - stats: Arc, + #[allow(unused)] stats: Arc, ) -> MpcNetworkError { while let Ok(msg) = outbound_stream.recv().await { #[cfg(feature = "stats")] diff --git a/src/network/quic.rs b/src/network/quic.rs index 71f4742..3cb472e 100644 --- a/src/network/quic.rs +++ b/src/network/quic.rs @@ -277,7 +277,7 @@ where { type Item = Result, MpcNetworkError>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Box::pin(self.get_mut().receive_message()) .as_mut() .poll(cx) @@ -291,7 +291,7 @@ where { type Error = MpcNetworkError; - fn start_send(mut self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> { if !self.connected { return Err(MpcNetworkError::NetworkUninitialized); }