Skip to content

Commit

Permalink
chore: Address feedback related to fixed base msm formula (#278)
Browse files Browse the repository at this point in the history
* sign -> is_positive

* change new_differences -> denominators and add assert for BATCH_INVERSE_THRESHOLD
  • Loading branch information
kevaundray authored Sep 25, 2024
1 parent fce5942 commit a7ede22
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
36 changes: 21 additions & 15 deletions cryptography/bls12_381/src/batch_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,35 +55,37 @@ pub fn batch_addition_binary_tree_stride(mut points: Vec<G1Affine>) -> G1Project
return G1Projective::identity();
}

let mut new_differences = Vec::with_capacity(points.len());

let mut denominators = Vec::with_capacity(points.len());
let mut sum = G1Projective::identity();

assert!(
BATCH_INVERSE_THRESHOLD >= 2,
"THRESHOLD cannot be below the number of points needed for group addition"
);
while points.len() > BATCH_INVERSE_THRESHOLD {
if points.len() % 2 != 0 {
sum += points
.pop()
.expect("infallible; since points has an odd length");
}
new_differences.clear();
denominators.clear();

for i in (0..=points.len() - 2).step_by(2) {
let p1 = points[i];
let p2 = points[i + 1];
new_differences.push(choose_add_or_double(p1, p2));
denominators.push(choose_add_or_double(p1, p2));
}

batch_inverse(&mut new_differences);

for (i, inv) in (0..=points.len() - 2).step_by(2).zip(&new_differences) {
batch_inverse(&mut denominators);
for (i, inv) in (0..=points.len() - 2).step_by(2).zip(&denominators) {
let p1 = points[i];
let p2 = points[i + 1];
points[i / 2] = point_add_double(p1, p2, inv);
}

// The latter half of the vector is now unused,
// all results are stored in the former half.
points.truncate(new_differences.len())
points.truncate(denominators.len())
}

for point in points {
Expand Down Expand Up @@ -127,11 +129,15 @@ pub fn multi_batch_addition_binary_tree_stride(
.sum()
}

let mut new_differences = Vec::with_capacity(max_bucket_length);
let mut denominators = Vec::with_capacity(max_bucket_length);
let mut total_amount_of_work = compute_threshold(&multi_points);

let mut sums = vec![G1Projective::identity(); multi_points.len()];

assert!(
BATCH_INVERSE_THRESHOLD >= 2,
"THRESHOLD cannot be below the number of points needed for group addition"
);
// TODO: total_amount_of_work does not seem to be changing performance that much
while total_amount_of_work > BATCH_INVERSE_THRESHOLD {
// For each point, we check if they are odd and pop off
Expand All @@ -143,7 +149,7 @@ pub fn multi_batch_addition_binary_tree_stride(
}
}

new_differences.clear();
denominators.clear();

// For each pair of points over all
// vectors, we collect them and put them in the
Expand All @@ -153,21 +159,21 @@ pub fn multi_batch_addition_binary_tree_stride(
continue;
}
for i in (0..=points.len() - 2).step_by(2) {
new_differences.push(choose_add_or_double(points[i], points[i + 1]));
denominators.push(choose_add_or_double(points[i], points[i + 1]));
}
}

batch_inverse_scratch_pad(&mut new_differences, &mut scratchpad);
batch_inverse_scratch_pad(&mut denominators, &mut scratchpad);

let mut new_differences_offset = 0;
let mut denominators_offset = 0;

for points in multi_points.iter_mut() {
if points.len() < 2 {
continue;
}
for (i, inv) in (0..=points.len() - 2)
.step_by(2)
.zip(&new_differences[new_differences_offset..])
.zip(&denominators[denominators_offset..])
{
let p1 = points[i];
let p2 = points[i + 1];
Expand All @@ -178,7 +184,7 @@ pub fn multi_batch_addition_binary_tree_stride(
// The latter half of the vector is now unused,
// all results are stored in the former half.
points.truncate(num_points);
new_differences_offset += num_points
denominators_offset += num_points
}

total_amount_of_work = compute_threshold(&multi_points);
Expand Down
4 changes: 2 additions & 2 deletions cryptography/bls12_381/src/fixed_base_msm_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ impl FixedBaseMSMPrecompWindow {
if point_idx == 0 {
continue;
}
let sign = point_idx.is_positive();
let is_scalar_positive = point_idx.is_positive();
let point_idx = point_idx.unsigned_abs() as usize - 1;
let mut point = sub_table[point_idx];
if !sign {
if !is_scalar_positive {
point = -point;
}

Expand Down

0 comments on commit a7ede22

Please sign in to comment.