Skip to content

Commit

Permalink
integration, algebra: Fixup tests and get to passing state over gener…
Browse files Browse the repository at this point in the history
…al curves
  • Loading branch information
joeykraut committed Oct 10, 2023
1 parent 470bbfa commit 344ab6f
Show file tree
Hide file tree
Showing 15 changed files with 186 additions and 180 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ ark-curve25519 = { version = "0.4", optional = true }
ark-ec = { version = "0.4", features = ["parallel"] }
ark-ff = "0.4"
ark-serialize = "0.4"
ark-std = "0.4"
digest = "0.10"
num-bigint = "0.4"
rand = "0.8"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
use itertools::Itertools;
use mpc_stark::{
algebra::{
authenticated_stark_point::{
authenticated_curve::{
test_helpers::{modify_mac, modify_public_modifier, modify_share},
AuthenticatedStarkPointResult,
AuthenticatedPointResult,
},
scalar::Scalar,
},
Expand Down Expand Up @@ -144,10 +144,9 @@ fn test_batch_add(test_args: &IntegrationTestArgs) -> Result<(), String> {
let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args);
let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args);

let res = AuthenticatedStarkPointResult::batch_add(&party0_values, &party1_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_add(&party0_values, &party1_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand All @@ -172,10 +171,9 @@ fn test_batch_add_public(test_args: &IntegrationTestArgs) -> Result<(), String>
// Add the points in the MPC circuit
let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args);

let res = AuthenticatedStarkPointResult::batch_add_public(&party0_values, &plaintext_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_add_public(&party0_values, &plaintext_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand Down Expand Up @@ -240,10 +238,9 @@ fn test_batch_sub(test_args: &IntegrationTestArgs) -> Result<(), String> {
let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args);
let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args);

let res = AuthenticatedStarkPointResult::batch_sub(&party0_values, &party1_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_sub(&party0_values, &party1_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand All @@ -268,10 +265,9 @@ fn test_batch_sub_public(test_args: &IntegrationTestArgs) -> Result<(), String>
// Add the points in the MPC circuit
let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args);

let res = AuthenticatedStarkPointResult::batch_sub_public(&party0_values, &plaintext_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_sub_public(&party0_values, &plaintext_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand Down Expand Up @@ -310,10 +306,9 @@ fn test_batch_negation(test_args: &IntegrationTestArgs) -> Result<(), String> {

// Compute the expected result in an MPC circuit
let party0_values = share_authenticated_point_batch(my_values, PARTY0, test_args);
let res = AuthenticatedStarkPointResult::batch_neg(&party0_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_neg(&party0_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(expected_res, res_open)
}
Expand Down Expand Up @@ -382,10 +377,9 @@ fn test_batch_mul(test_args: &IntegrationTestArgs) -> Result<(), String> {
let party0_values = share_authenticated_point_batch(my_vals.clone(), PARTY0, test_args);
let party1_values = share_authenticated_point_batch(my_vals, PARTY1, test_args);

let res = AuthenticatedStarkPointResult::batch_sub(&party0_values, &party1_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_sub(&party0_values, &party1_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand All @@ -410,10 +404,9 @@ fn test_batch_mul_public(test_args: &IntegrationTestArgs) -> Result<(), String>
// Add the points in the MPC circuit
let party0_values = share_authenticated_point_batch(my_vals, PARTY0, test_args);

let res = AuthenticatedStarkPointResult::batch_sub_public(&party0_values, &plaintext_values);
let res_open = await_batch_result_with_error(
AuthenticatedStarkPointResult::open_authenticated_batch(&res),
)?;
let res = AuthenticatedPointResult::batch_sub_public(&party0_values, &plaintext_values);
let res_open =
await_batch_result_with_error(AuthenticatedPointResult::open_authenticated_batch(&res))?;

assert_point_batches_eq(res_open, expected_result)
}
Expand Down
8 changes: 4 additions & 4 deletions integration/authenticated_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
await_result, await_result_batch, await_result_with_error, share_authenticated_scalar,
share_authenticated_scalar_batch, share_plaintext_value, share_plaintext_values_batch,
},
IntegrationTest, IntegrationTestArgs,
IntegrationTest, IntegrationTestArgs, TestScalar,
};

// -----------
Expand Down Expand Up @@ -108,7 +108,7 @@ fn test_add_public_value(test_args: &IntegrationTestArgs) -> Result<(), String>
let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric);
let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric);

let plaintext_constant: Scalar = await_result(party1_value);
let plaintext_constant: TestScalar = await_result(party1_value);
let expected_result = await_result(party0_value) + plaintext_constant;

// Compute the result in the MPC circuit
Expand Down Expand Up @@ -213,7 +213,7 @@ fn test_sub_public_scalar(test_args: &IntegrationTestArgs) -> Result<(), String>
let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric);
let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric);

let plaintext_constant: Scalar = await_result(party1_value);
let plaintext_constant: TestScalar = await_result(party1_value);
let expected_result = await_result(party0_value) - plaintext_constant;

// Compute the result in the MPC circuit
Expand Down Expand Up @@ -365,7 +365,7 @@ fn test_mul_public_scalar(test_args: &IntegrationTestArgs) -> Result<(), String>
let party0_value = share_plaintext_value(my_value.clone(), PARTY0, &test_args.fabric);
let party1_value = share_plaintext_value(my_value, PARTY1, &test_args.fabric);

let plaintext_constant: Scalar = await_result(party1_value);
let plaintext_constant: TestScalar = await_result(party1_value);
let expected_result = await_result(party0_value) * plaintext_constant;

// Compute the result in the MPC circuit
Expand Down
16 changes: 8 additions & 8 deletions integration/circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
use itertools::Itertools;
use mpc_stark::{
algebra::{
authenticated_scalar::AuthenticatedScalarResult,
authenticated_stark_point::AuthenticatedStarkPointResult, scalar::Scalar,
stark_curve::StarkPoint,
authenticated_curve::AuthenticatedPointResult,
authenticated_scalar::AuthenticatedScalarResult, scalar::Scalar,
},
random_point, PARTY0, PARTY1,
};
Expand All @@ -16,7 +15,7 @@ use crate::{
assert_points_eq, assert_scalars_eq, await_result, await_result_batch,
share_plaintext_value, share_plaintext_values_batch,
},
IntegrationTest, IntegrationTestArgs,
IntegrationTest, IntegrationTestArgs, TestCurve, TestCurvePoint, TestScalar,
};

/// Tests an inner product between two vectors of shared scalars
Expand All @@ -40,7 +39,7 @@ fn test_inner_product(test_args: &IntegrationTestArgs) -> Result<(), String> {
let b_plaintext =
await_result_batch(&share_plaintext_values_batch(&allocd_vals, PARTY1, fabric));

let expected_res: Scalar = a_plaintext
let expected_res: TestScalar = a_plaintext
.iter()
.zip(b_plaintext)
.map(|(a, b)| a * b)
Expand All @@ -57,7 +56,8 @@ fn test_inner_product(test_args: &IntegrationTestArgs) -> Result<(), String> {
.collect_vec();

// Compute the inner product
let res: AuthenticatedScalarResult = a.iter().zip(b.iter()).map(|(a, b)| a * b).sum();
let res: AuthenticatedScalarResult<TestCurve> =
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum();
let res_open = await_result(res.open_authenticated())
.map_err(|err| format!("error opening result: {err:?}"))?;

Expand Down Expand Up @@ -96,7 +96,7 @@ fn test_msm(test_args: &IntegrationTestArgs) -> Result<(), String> {
fabric,
));

let expected_res = StarkPoint::msm(&plaintext_scalars, &plaintext_points);
let expected_res = TestCurvePoint::msm(&plaintext_scalars, &plaintext_points);

// Share the values in an MPC circuit
let shared_scalars = my_scalars
Expand All @@ -109,7 +109,7 @@ fn test_msm(test_args: &IntegrationTestArgs) -> Result<(), String> {
.collect_vec();

// Compare results
let res = AuthenticatedStarkPointResult::msm(&shared_scalars, &shared_points);
let res = AuthenticatedPointResult::msm(&shared_scalars, &shared_points);
let res_open = await_result(res.open_authenticated())
.map_err(|err| format!("error opening msm result: {err:?}"))?;

Expand Down
4 changes: 2 additions & 2 deletions integration/fabric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ fn test_fabric_share_and_open(test_args: &IntegrationTestArgs) -> Result<(), Str
let party0_value_opened = party0_value.open();
let party0_res = await_result(party0_value_opened);

assert_scalars_eq(party0_res, Scalar::from(0))?;
assert_scalars_eq(party0_res, Scalar::from(0u8))?;

// Party 1
let party1_value = share_scalar(my_party_id, PARTY1, test_args);
let party1_value_opened = party1_value.open();
let party1_res = await_result(party1_value_opened);

assert_scalars_eq(party1_res, Scalar::from(1))
assert_scalars_eq(party1_res, Scalar::from(1u8))
}

inventory::submit!(IntegrationTest {
Expand Down
Loading

0 comments on commit 344ab6f

Please sign in to comment.