diff --git a/cryptography/bls12_381/Cargo.toml b/cryptography/bls12_381/Cargo.toml index 6c854c15..6435be53 100644 --- a/cryptography/bls12_381/Cargo.toml +++ b/cryptography/bls12_381/Cargo.toml @@ -19,17 +19,16 @@ blstrs = { version = "0.7.1", features = ["__private_bench"] } ff = "0.13.0" group = "0.13" pairing = { version = "0.23" } - # Transitively, we depend on subtle version >=2.5.0 # Adding the restrictions here codify it in rust-eth-kzg. # # See https://github.com/crate-crypto/rust-eth-kzg/issues/235 for more info # as to why we need to pull it in here, even though it is not used directly. subtle = { version = ">=2.5.0, <3.0" } - +rand = "0.8.4" +rayon = { workspace = true } [dev-dependencies] criterion = "0.5.1" -rand = "0.8.4" [features] blst-no-threads = ["blst/no-threads"] diff --git a/cryptography/bls12_381/benches/benchmark.rs b/cryptography/bls12_381/benches/benchmark.rs index 25803359..3e1d78ae 100644 --- a/cryptography/bls12_381/benches/benchmark.rs +++ b/cryptography/bls12_381/benches/benchmark.rs @@ -1,15 +1,23 @@ use blstrs::Scalar; +use crate_crypto_internal_eth_kzg_bls12_381::batch_add::batch_addition; +use crate_crypto_internal_eth_kzg_bls12_381::batch_add::batch_addition_diff_stride; +use crate_crypto_internal_eth_kzg_bls12_381::batch_add::multi_batch_addition; +use crate_crypto_internal_eth_kzg_bls12_381::batch_add::multi_batch_addition_diff_stride; +use crate_crypto_internal_eth_kzg_bls12_381::fixed_base_msm_blst::FixedBaseMultiMSMPrecompBLST; +// use crate_crypto_internal_eth_kzg_bls12_381::fixed_base_msm_pippenger::pippenger_fixed_base_msm_wnaf; use crate_crypto_internal_eth_kzg_bls12_381::{ batch_inversion, ff::Field, fixed_base_msm::{FixedBaseMSM, UsePrecomp}, + fixed_base_msm_blst::FixedBaseMSMPrecompBLST, fixed_base_msm_pippenger::FixedBaseMSMPippenger, g1_batch_normalize, g2_batch_normalize, group::Group, lincomb::{g1_lincomb, g1_lincomb_unsafe, g2_lincomb, g2_lincomb_unsafe}, - G1Projective, G2Projective, + G1Point, G1Projective, G2Projective, }; -use criterion::{criterion_group, criterion_main, Criterion}; + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; pub fn batch_inversion(c: &mut Criterion) { const NUM_ELEMENTS: usize = 8192; @@ -26,22 +34,133 @@ pub fn batch_inversion(c: &mut Criterion) { } pub fn fixed_base_msm(c: &mut Criterion) { let length = 64; - let generators: Vec<_> = random_g1_points(length) - .into_iter() - .map(|p| p.into()) + // let generators: Vec> = random_g1_points(length) + // .into_iter() + // .map(|p| G1Point::from(p)) + // .collect(); + // let scalars: Vec<_> = random_scalars(length); + // let fbm = FixedBaseMSM::new(generators.clone(), UsePrecomp::Yes { width: 8 }); + + // c.bench_function("bls12_381 fixed_base_msm length=64 width=8", |b| { + // b.iter(|| fbm.msm(scalars.clone())) + // }); + + // let fixed_base_pip = FixedBaseMSMPippenger::new(&generators); + + // c.bench_function("bls12_381 fixed based pippenger algorithm wnaf", |b| { + // b.iter(|| fixed_base_pip.msm(&scalars)) + // }); + + // c.bench_function("bls12_381 fixed based pippenger algorithm", |b| { + // b.iter(|| fixed_base_pip.msm(&scalars)) + // }); + + // let mut group = c.benchmark_group("bls12_381 fixed base windowed algorithm"); + + // for window_size in 7..=14 { + // // Test window sizes from 2 to 10 + // // Create the FixedBaseMSMPrecompBLST instance outside the benchmarked portion + // let fixed_base = FixedBaseMSMPrecompBLST::new(&generators, window_size); + + // group.bench_with_input( + // 
BenchmarkId::new("window_size", window_size), + // &window_size, + // |b, &_| b.iter(|| black_box(fixed_base.msm(black_box(&scalars)))), + // ); + // } + // group.finish(); +} + +pub fn multi_fixed_base_msm(c: &mut Criterion) { + let length: usize = 64; + // let generators: Vec<_> = random_g1_points(length) + // .into_iter() + // .map(|p| p.into()) + // .collect(); + // let scalars: Vec<_> = random_scalars(length); + let num_sets = 128; + + let scalars_sets: Vec<_> = (0..num_sets).map(|_| random_scalars(length)).collect(); + let points_sets: Vec<_> = (0..num_sets) + .map(|_| { + random_g1_points(length) + .into_iter() + .map(|p| p.into()) + .collect() + }) .collect(); - let fbm = FixedBaseMSM::new(generators.clone(), UsePrecomp::Yes { width: 8 }); - let scalars: Vec<_> = random_scalars(length); - c.bench_function("bls12_381 fixed_base_msm length=64 width=8", |b| { - b.iter(|| fbm.msm(scalars.clone())) + // let fbm = FixedBaseMSM::new(generators.clone(), UsePrecomp::Yes { width: 8 }); + let multi_msm = FixedBaseMultiMSMPrecompBLST::new(points_sets, 8); + c.bench_function("bls12_381 fixed_base_multi_msm", |b| { + b.iter(|| multi_msm.multi_msm(scalars_sets.clone())) }); - let fixed_base_pip = FixedBaseMSMPippenger::new(&generators); + // let fixed_base_pip = FixedBaseMSMPippenger::new(&generators); - c.bench_function("bls12_381 fixed based pippenger algorithm", |b| { - b.iter(|| fixed_base_pip.msm(&scalars)) - }); + // c.bench_function("bls12_381 fixed based pippenger algorithm wnaf", |b| { + // b.iter(|| fixed_base_pip.msm(&scalars)) + // }); + + // c.bench_function("bls12_381 fixed based pippenger algorithm", |b| { + // b.iter(|| fixed_base_pip.msm(&scalars)) + // }); + + // let mut group = c.benchmark_group("bls12_381 fixed base windowed algorithm"); + + // for window_size in 7..=14 { + // // Test window sizes from 2 to 10 + // // Create the FixedBaseMSMPrecompBLST instance outside the benchmarked portion + // let fixed_base = FixedBaseMSMPrecompBLST::new(&generators, window_size); + + // group.bench_with_input( + // BenchmarkId::new("window_size", window_size), + // &window_size, + // |b, &_| b.iter(|| black_box(fixed_base.msm(black_box(&scalars)))), + // ); + // } + // group.finish(); +} + +pub fn bench_batch_addition(c: &mut Criterion) { + let mut group = c.benchmark_group("batch addition"); + + for length in [64, 128, 256, 512, 1024] { + let vector_length = 8; + + let generators: Vec<_> = (0..vector_length) + .map(|_| { + random_g1_points(length) + .into_iter() + .map(|p| p.into()) + .collect() + }) + .collect(); + + group.bench_with_input( + BenchmarkId::new("length-normal", length), + &length, + |b, &_| b.iter(|| black_box(multi_batch_addition(generators.clone()))), + ); + + group.bench_with_input( + BenchmarkId::new("length-naive", length), + &length, + |b, &_| { + b.iter(|| { + for point in &generators { + black_box(batch_addition_diff_stride(point.clone())); + } + }) + }, + ); + group.bench_with_input( + BenchmarkId::new("length-diff-stride", length), + &length, + |b, &_| b.iter(|| black_box(multi_batch_addition_diff_stride(generators.clone()))), + ); + } + group.finish(); } pub fn bench_msm(c: &mut Criterion) { @@ -97,9 +216,10 @@ fn random_g2_points(size: usize) -> Vec { criterion_group!( benches, - batch_inversion, - fixed_base_msm, - bench_msm, - fixed_base_msm + // batch_inversion, + // fixed_base_msm, + // bench_msm, + // fixed_base_msm // bench_batch_addition + multi_fixed_base_msm ); criterion_main!(benches); diff --git a/cryptography/bls12_381/src/batch_add.rs 
b/cryptography/bls12_381/src/batch_add.rs index ecf5a3a0..a894c19e 100644 --- a/cryptography/bls12_381/src/batch_add.rs +++ b/cryptography/bls12_381/src/batch_add.rs @@ -1,5 +1,11 @@ -use crate::batch_inversion::{batch_inverse, batch_inverse_scratch_pad}; -use blstrs::{Fp, G1Affine}; +use crate::{ + batch_add_blst::G1AffineInv, + batch_inversion::{batch_inverse, batch_inverse_scratch_pad}, +}; +use ff::Field; + +use blstrs::{Fp, G1Affine, G1Projective}; +use group::Group; /// Adds multiple points together in affine representation, batching the inversions pub fn batch_addition(mut points: Vec) -> G1Affine { @@ -42,6 +48,66 @@ pub fn batch_addition(mut points: Vec) -> G1Affine { points[0] } + +// top down balanced tree idea - benedikt +// search tree for sorted array +pub fn batch_addition_diff_stride(mut points: Vec) -> G1Projective { + #[inline(always)] + fn point_add(p1: G1Affine, p2: G1Affine, inv: &blstrs::Fp) -> G1Affine { + use ff::Field; + + let lambda = (p2.y() - p1.y()) * inv; + let x = lambda.square() - p1.x() - p2.x(); + let y = lambda * (p1.x() - x) - p1.y(); + G1Affine::from_raw_unchecked(x, y, false) + } + + if points.is_empty() { + use group::prime::PrimeCurveAffine; + use group::Group; + return G1Projective::identity(); + } + + let mut new_differences = Vec::with_capacity(points.len()); + + let mut points_len = points.len(); + + let mut sum = G1Projective::identity(); + + const BATCH_INVERSE_THRESHOLD: usize = 16; + + while points.len() > BATCH_INVERSE_THRESHOLD { + if points.len() % 2 != 0 { + sum += points + .pop() + .expect("infallible; since points has an odd length"); + } + new_differences.clear(); + + for i in (0..=points.len() - 2).step_by(2) { + new_differences.push(points[i + 1].x() - points[i].x()); + } + + batch_inverse(&mut new_differences); + // + for (i, inv) in (0..=points.len() - 2).step_by(2).zip(&new_differences) { + let p1 = points[i]; + let p2 = points[i + 1]; + points[i / 2] = point_add(p1, p2, inv); + } + + // The latter half of the vector is now unused, + // all results are stored in the former half. + points.truncate(new_differences.len()) + } + + for point in points { + sum += point + } + + sum +} + // This method assumes that adjacent points are not the same // This will lead to an inversion by zero pub fn batch_addition_mut(points: &mut [G1Affine]) -> G1Affine { @@ -88,10 +154,10 @@ pub fn batch_addition_mut(points: &mut [G1Affine]) -> G1Affine { // TODO so we want to check if it makes a difference in our usecase. 
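// Editor's sketch (not part of the patch): `batch_addition_diff_stride` above halves the
// point list each round and pays for all the per-pair slope denominators with a single
// field inversion. The amortization trick is Montgomery batch inversion; a minimal,
// self-contained version over any `ff::Field` looks like the following (the crate's
// `batch_inverse` is assumed to be equivalent in effect).
use ff::Field;

/// Invert every element of `values` in place using one field inversion.
/// Assumes no element is zero (a zero would make the running product zero).
fn montgomery_batch_inverse<F: Field>(values: &mut [F]) {
    if values.is_empty() {
        return;
    }
    // Prefix products: prods[i] = values[0] * ... * values[i]
    let mut prods = Vec::with_capacity(values.len());
    let mut acc = F::ONE;
    for v in values.iter() {
        acc *= v;
        prods.push(acc);
    }
    // One inversion of the total product...
    let mut inv = acc.invert().unwrap();
    // ...then walk backwards, peeling one element off per step.
    for i in (1..values.len()).rev() {
        let next_inv = inv * values[i];
        values[i] = inv * prods[i - 1];
        inv = next_inv;
    }
    values[0] = inv;
}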
pub fn multi_batch_addition(mut multi_points: Vec>) -> Vec { #[inline(always)] - fn point_add_double(p1: G1Affine, p2: G1Affine, inv: &blstrs::Fp) -> G1Affine { + fn point_add_double(p1: &mut G1Affine, p2: G1Affine, inv: &blstrs::Fp) { use ff::Field; - let lambda = if p1 == p2 { + let lambda = if *p1 == p2 { p1.x().square().mul3() * inv } else { (p2.y() - p1.y()) * inv @@ -99,7 +165,8 @@ pub fn multi_batch_addition(mut multi_points: Vec>) -> Vec>) -> Vec>) -> Vec>) -> Vec { + #[inline(always)] + fn point_add_double(p1: G1Affine, p2: G1Affine, inv: &blstrs::Fp) -> G1Affine { + use ff::Field; + + let lambda = if p1 == p2 { + p1.x().square().mul3() * inv + } else { + (p2.y() - p1.y()) * inv + }; + + let x = lambda.square() - p1.x() - p2.x(); + let y = lambda * (p1.x() - x) - p1.y(); + + G1Affine::from_raw_unchecked(x, y, false) + } + #[inline(always)] + // Note: We do not handle the case where p1 == -p2 + fn choose_add_or_double(p1: G1Affine, p2: G1Affine) -> Fp { + use ff::Field; + + if p1 == p2 { + p2.y().double() + } else { + p1.x() - p2.x() + } + } + + let total_num_points: usize = multi_points.iter().map(|p| p.len()).sum(); + let mut scratchpad = Vec::with_capacity(total_num_points); + + // Find the largest buckets, this will be the bottleneck for the number of iterations + let mut max_bucket_length = 0; + for points in multi_points.iter() { + max_bucket_length = std::cmp::max(max_bucket_length, points.len()); + } + + // Compute the total number of "unit of work" + // In the single batch addition case this is analogous to + // the batch inversion threshold + #[inline(always)] + fn compute_threshold(points: &[Vec]) -> usize { + points + .iter() + .map(|p| { + if p.len() % 2 == 0 { + p.len() / 2 + } else { + (p.len() - 1) / 2 + } + }) + .sum() + } + + let mut new_differences = Vec::with_capacity(max_bucket_length); + let mut total_amount_of_work = compute_threshold(&multi_points); + + let mut sums = vec![G1Projective::identity(); multi_points.len()]; + + // TODO: total_amount_of_work does not seem to be changing performance that much + while total_amount_of_work > 16 { + // For each point, we check if they are odd and pop off + // one of the points + for (points, sum) in multi_points.iter_mut().zip(sums.iter_mut()) { + // Make the points even + if points.len() % 2 != 0 { + *sum += points.pop().unwrap(); + } + } + + new_differences.clear(); + + // For each pair of points over all + // vectors, we collect them and put them in the + // inverse array + for points in multi_points.iter() { + if points.len() < 2 { + continue; + } + for i in (0..=points.len() - 2).step_by(2) { + // new_differences.push(points[i + 1].x() - points[i].x()); + new_differences.push(choose_add_or_double(points[i], points[i + 1])); + } + } + + batch_inverse_scratch_pad(&mut new_differences, &mut scratchpad); + // new_differences.reverse(); + + let mut new_differences_offset = 0; + + for points in multi_points.iter_mut() { + if points.len() < 2 { + continue; + } + for (i, inv) in (0..=points.len() - 2) + .step_by(2) + // .zip(new_differences.iter().rev()) + .zip(&new_differences[new_differences_offset..]) + { + let p1 = points[i]; + let p2 = points[i + 1]; + points[i / 2] = point_add_double(p1, p2, inv); + } + + let num_points = points.len() / 2; + // The latter half of the vector is now unused, + // all results are stored in the former half. 
+ points.truncate(num_points); + // new_differences = new_differences[num_points..].to_vec(); + new_differences_offset += num_points + } + + total_amount_of_work = compute_threshold(&multi_points); + } + + for (sum, points) in sums.iter_mut().zip(multi_points) { + for point in points { + *sum += point + } + } + + sums +} + #[cfg(test)] mod tests { + use crate::batch_add::{batch_addition_diff_stride, multi_batch_addition_diff_stride}; + use super::{batch_addition, multi_batch_addition}; use blstrs::{G1Affine, G1Projective}; use group::Group; #[test] fn test_batch_addition() { - let num_points = 100; + let num_points = 101; let points: Vec = (0..num_points) .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) .collect(); @@ -204,8 +397,8 @@ mod tests { .fold(G1Projective::identity(), |acc, p| acc + p) .into(); - let got_result = batch_addition(points.clone()); - assert_eq!(expected_result, got_result); + let got_result = batch_addition_diff_stride(points.clone()); + assert_eq!(expected_result, got_result.into()); } #[test] @@ -229,4 +422,26 @@ mod tests { let got_results = multi_batch_addition(random_sets_of_points_clone); assert_eq!(got_results, expected_results); } + + #[test] + fn test_multi_batch_addition_diff_stride() { + let num_points = 99; + let num_sets = 5; + let random_sets_of_points: Vec> = (0..num_sets) + .map(|_| { + (0..num_points) + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect() + }) + .collect(); + let random_sets_of_points_clone = random_sets_of_points.clone(); + + let expected_results: Vec = random_sets_of_points + .into_iter() + .map(|points| batch_addition(points).into()) + .collect(); + + let got_results = multi_batch_addition_diff_stride(random_sets_of_points_clone); + assert_eq!(got_results, expected_results); + } } diff --git a/cryptography/bls12_381/src/batch_add_blst.rs b/cryptography/bls12_381/src/batch_add_blst.rs new file mode 100644 index 00000000..ca6f24a8 --- /dev/null +++ b/cryptography/bls12_381/src/batch_add_blst.rs @@ -0,0 +1,189 @@ +use blstrs::{Fp, G1Affine, G1Projective}; +use ff::Field; +use subtle::{Choice, ConditionallySelectable}; + +#[derive(Debug, Clone, Copy)] +pub struct G1AffineInv { + pub x: Fp, + pub y: Fp, + pub tmp: Fp, +} + +impl From for G1AffineInv { + fn from(value: G1Affine) -> Self { + Self { + x: value.x(), + y: value.y(), + tmp: Fp::ZERO, + } + } +} + +// fn is_zero(point: &G1AffineInv) -> Choice { +// point.x.is_zero() & point.y.is_zero() +// } + +// /* + +// COPIED FROM BLST!!!!!!!!!! + +// * This implementation uses explicit addition formula: +// * +// * λ = (Y₂-Y₁)/(X₂-X₁) +// * X₃ = λ²-(X₁+X₂) +// * Y₃ = λ⋅(X₁-X₃)-Y₁ +// * +// * But since we don't know if we'll have to add point to itself, we need +// * to eventually resort to corresponding doubling formula: +// * +// * λ = 3X₁²/2Y₁ +// * X₃ = λ²-2X₁ +// * Y₃ = λ⋅(X₁-X₃)-Y₁ +// * +// * The formulae use prohibitively expensive inversion, but whenever we +// * have a lot of affine points to accumulate, we can amortize the cost +// * by applying Montgomery's batch inversion approach. As a result, +// * asymptotic[!] per-point cost for addition is as small as 5M+1S. For +// * comparison, ptype##_dadd_affine takes 8M+5S. In practice, all things +// * considered, the improvement coefficient varies from 60% to 85% +// * depending on platform and curve. +// * +// * THIS IMPLEMENTATION IS *NOT* CONSTANT-TIME. [But if there is an +// * application that requires constant time-ness, speak up!] 
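// Editor's sketch (hypothetical test, not part of the patch): the explicit formulas quoted
// in the surrounding blst notes — λ = (Y₂-Y₁)/(X₂-X₁) for addition and λ = 3X₁²/(2Y₁) for
// doubling — are the same ones `point_add_double` and `choose_add_or_double` feed into the
// batched inversion above. A quick cross-check of the doubling case against the blstrs
// group law:
#[test]
fn affine_doubling_formula_matches_group_law() {
    use blstrs::{G1Affine, G1Projective};
    use ff::Field;
    use group::Group;

    let p: G1Affine = G1Projective::random(&mut rand::thread_rng()).into();
    // λ = 3x² / 2y, x₃ = λ² − 2x, y₃ = λ(x − x₃) − y
    let lambda = p.x().square().mul3() * p.y().double().invert().unwrap();
    let x3 = lambda.square() - p.x() - p.x();
    let y3 = lambda * (p.x() - x3) - p.y();
    let doubled = G1Affine::from_raw_unchecked(x3, y3, false);
    assert_eq!(G1Projective::from(doubled), G1Projective::from(p).double());
}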
+// */ +// /* +// * Calculate λ's numerator and denominator. +// * +// * input: A x1 y1 - +// * B x2 y2 - +// * output: +// * if A!=B: A x1 y1 (x2-x1)*mul_acc +// * B x2+x1 y2-y1 (x2-x1) +// * +// * if A==B: A x y 2y*mul_acc +// * B 2x 3*x^2 2y +// * +// * if A==-B: A 0 0 1*mul_acc +// * B 0 3*x^2 0 +// */ +// fn head(a: &mut G1AffineInv, b: &mut G1AffineInv, mul_acc: Option<&Fp>) { +// let inf = is_zero(a) | is_zero(b); +// let zero = Fp::ZERO; +// let one = Fp::ONE; + +// // X2-X1 +// b.tmp = b.x - a.x; + +// // X2+X1 +// let x_sum = b.x + a.x; + +// // Y2+Y1 +// let y_sum = b.y + a.y; + +// // Y2-Y1 +// let y_diff = b.y - a.y; + +// if b.tmp.is_zero().into() { +// // X2==X1 +// let inf_inner = a.tmp.is_zero(); +// b.x = Fp::conditional_select(&b.x, &a.tmp, inf_inner); + +// // 3*X1^2 +// b.y = a.x.square(); +// b.y *= Fp::from(3u64); + +// // 2*Y1 +// b.tmp = a.tmp; +// } + +// // Conditional selections +// a.x = Fp::conditional_select(&a.x, &b.x, inf); +// a.y = Fp::conditional_select(&a.y, &a.tmp, inf); +// a.tmp = Fp::conditional_select(&one, &b.tmp, inf); +// b.tmp = Fp::conditional_select(&zero, &b.tmp, inf); + +// // Chain multiplication +// if let Some(acc) = mul_acc { +// a.tmp *= acc; +// } + +// // Update b +// b.x = x_sum; +// b.y = y_diff; +// } + +// fn tail(d: &mut G1AffineInv, ab: &[G1AffineInv; 2], mut lambda: Fp) { +// let a = &ab[0]; +// let b = &ab[1]; + +// let inf = b.tmp.is_zero(); +// let one = Fp::ONE; + +// // λ = (Y2-Y1)/(X2-X1) or 3*X1^2/2*Y1 +// lambda *= b.y; + +// // llambda = λ^2 +// let llambda = lambda.square(); + +// // X3 = λ^2 - (X2+X1) +// d.x = llambda - b.x; + +// // Y3 = λ*(X1-X3) - Y1 +// d.y = a.x - d.x; +// d.y *= lambda; +// d.y -= a.y; + +// // Conditional selection for point at infinity +// d.x = Fp::conditional_select(&d.x, &a.x, inf); +// d.y = Fp::conditional_select(&d.y, &a.y, inf); + +// // This seems to be +// // setting B->Z to 1 if the result is the point at infinity +// let mut b_tmp = Fp::conditional_select(&b.tmp, &one, inf); +// } + +// fn dadd_affine(sum: &mut G1Projective, point: &G1Affine) { +// sum.add_assign_mixed(point); +// } + +// fn accumulate(sum: &mut G1Projective, mut points: &mut [G1AffineInv]) { +// let mut n = points.len(); + +// while n >= 16 { +// if n & 1 != 0 { +// let affine_point = G1Affine::from_raw_unchecked(points[0].x, points[0].y, false); +// dadd_affine(sum, &affine_point); +// points = &mut points[1..]; +// n -= 1; +// } +// n /= 2; + +// let mut mul_acc = None; +// for i in 0..n { +// head(&mut points[2 * i], &mut points[2 * i + 1], mul_acc); +// mul_acc = Some(&points[2 * i].tmp); +// } + +// points[2 * n - 2].tmp = points[2 * n - 2].tmp.invert().unwrap(); + +// let mut dst = n; +// for i in (1..n).rev() { +// dst -= 1; +// points[2 * i - 2].tmp = points[2 * i - 2].tmp * points[2 * i].tmp; +// tail( +// &mut points[dst], +// &[points[2 * i - 2], points[2 * i - 1]], +// points[2 * i - 2].tmp, +// ); +// points[2 * i - 2].tmp = points[2 * i - 2].tmp * points[2 * i + 1].tmp; +// } +// dst -= 1; +// tail(&mut points[dst], &[points[0], points[1]], points[0].tmp); +// points = &mut points[..n]; +// } + +// for point in points.iter() { +// let affine_point = G1Affine::from_raw_unchecked(point.x, point.y, false); +// dadd_affine(sum, &affine_point); +// } +// } diff --git a/cryptography/bls12_381/src/fixed_base_msm.rs b/cryptography/bls12_381/src/fixed_base_msm.rs index e3e77150..d3668c62 100644 --- a/cryptography/bls12_381/src/fixed_base_msm.rs +++ b/cryptography/bls12_381/src/fixed_base_msm.rs @@ -1,4 
+1,10 @@ -use crate::{fixed_base_msm_pippenger::FixedBaseMSMPippenger, G1Projective, Scalar}; +use crate::{ + fixed_base_msm_blst::FixedBaseMSMPrecompBLST, + fixed_base_msm_blst_all_windows::FixedBaseMSMPrecompAllWindow, + fixed_base_msm_pippenger::FixedBaseMSMPippenger, + limlee::{LimLee, TsaurChou}, + G1Projective, Scalar, +}; use blstrs::{Fp, G1Affine}; /// FixedBaseMSMPrecomp computes a multi scalar multiplication using pre-computations. @@ -27,7 +33,8 @@ pub enum UsePrecomp { /// of memory. #[derive(Debug)] pub enum FixedBaseMSM { - Precomp(FixedBaseMSMPrecomp), + Precomp(FixedBaseMSMPrecompAllWindow), + // Precomp(LimLee), // TODO: We are hijacking the NoPrecomp variant to store the // TODO: new pippenger algorithm. NoPrecomp(FixedBaseMSMPippenger), @@ -37,7 +44,10 @@ impl FixedBaseMSM { pub fn new(generators: Vec, use_precomp: UsePrecomp) -> Self { match use_precomp { UsePrecomp::Yes { width } => { - FixedBaseMSM::Precomp(FixedBaseMSMPrecomp::new(generators, width)) + FixedBaseMSM::Precomp(FixedBaseMSMPrecompAllWindow::new(&generators, width)) + // FixedBaseMSM::Precomp(FixedBaseMSMPrecompBLST::new(&generators, width)) + // FixedBaseMSM::Precomp(TsaurChou::new(8, 4, &generators)) + // FixedBaseMSM::Precomp(LimLee::new(8, 1, &generators)) } UsePrecomp::No => FixedBaseMSM::NoPrecomp(FixedBaseMSMPippenger::new(&generators)), } @@ -45,7 +55,11 @@ impl FixedBaseMSM { pub fn msm(&self, scalars: Vec) -> G1Projective { match self { - FixedBaseMSM::Precomp(precomp) => precomp.msm(scalars), + FixedBaseMSM::Precomp(precomp) => { + // TsaurChau + // precomp.mul_naive_better_wnaf_precomputations_final_msm(&scalars) + precomp.msm(&scalars) + } FixedBaseMSM::NoPrecomp(precomp) => precomp.msm(&scalars), } } diff --git a/cryptography/bls12_381/src/fixed_base_msm_blst.rs b/cryptography/bls12_381/src/fixed_base_msm_blst.rs new file mode 100644 index 00000000..114f3602 --- /dev/null +++ b/cryptography/bls12_381/src/fixed_base_msm_blst.rs @@ -0,0 +1,254 @@ +use std::time::{Duration, Instant}; + +use crate::{ + batch_add::{ + batch_addition, batch_addition_diff_stride, multi_batch_addition, + multi_batch_addition_diff_stride, + }, + booth_encoding::{self, get_booth_index}, + fixed_base_msm_pippenger::FixedBaseMSMPippenger, + g1_batch_normalize, G1Projective, Scalar, +}; +use blstrs::{Fp, G1Affine}; +use ff::PrimeField; +use group::prime::PrimeCurveAffine; +use group::Group; + +// FixedBasePrecomp blst way with some changes +#[derive(Debug)] +pub struct FixedBaseMSMPrecompBLST { + table: Vec>, // TODO: Make this a Vec<> and then just do the maths in msm function for offsetting + // table: Vec, + wbits: usize, +} + +impl FixedBaseMSMPrecompBLST { + pub fn new(points: &[G1Affine], wbits: usize) -> Self { + // For every point `P`, wbits indicates that we should compute + // 1 * P, ..., (2^{wbits} - 1) * P + // + // The total amount of memory is roughly (numPoints * 2^{wbits} - 1) + // where each point is 64 bytes. 
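// Editor's sketch (not part of the patch; sizes are an estimate inferred from the code
// below): `precompute_points` actually stores 2^{wbits-1} multiples per base, since the
// Booth/signed-digit lookup in `msm` gets the other half for free by negation, and each
// normalized G1 affine point carries two 48-byte Fp coordinates. A back-of-the-envelope
// helper for the table size:
fn precomp_table_bytes(num_points: usize, wbits: usize) -> usize {
    // two Fp coordinates per affine point, 48 bytes each
    const G1_AFFINE_COORD_BYTES: usize = 2 * 48;
    num_points * (1usize << (wbits - 1)) * G1_AFFINE_COORD_BYTES
}
// Example: 64 bases at wbits = 8 -> 64 * 128 * 96 = 786,432 bytes (~768 KiB).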
+ // + // + let mut precomputed_points: Vec<_> = points + .into_iter() + .map(|point| Self::precompute_points(wbits, *point)) + .collect(); + + Self { + table: precomputed_points, + wbits, + } + } + // Given a point, we precompute P,..., (2^w -1) * P + fn precompute_points(wbits: usize, point: G1Affine) -> Vec { + let mut lookup_table = Vec::with_capacity(1 << (wbits - 1)); + + // Convert to projective for faster operations + let mut current = G1Projective::from(point); + + // Compute and store multiples + for _ in 0..(1 << (wbits - 1)) { + lookup_table.push(current); + current += point; + } + + g1_batch_normalize(&lookup_table) + } + + pub fn msm(&self, scalars: &[Scalar]) -> G1Projective { + let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect(); + let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1; + + let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows]; + + for window_idx in 0..number_of_windows { + for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() { + let sub_table = &self.table[scalar_idx]; + let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref()); + + if point_idx == 0 { + continue; + } + let sign = point_idx.is_positive(); + let point_idx = point_idx.unsigned_abs() as usize - 1; + let mut point = sub_table[point_idx]; + // let mut point = self.table[(scalar_idx * 1 << self.wbits) + point_idx]; + if !sign { + point = -point; + } + + windows_of_points[window_idx].push(point); + } + } + + // For each window, lets add all of the points together. + // let accumulated_points: Vec<_> = windows_of_points + // .into_iter() + // .map(|wp| batch_addition_diff_stride(wp)) + // .collect(); + let accumulated_points = multi_batch_addition_diff_stride(windows_of_points); + + // Now accumulate the windows by doubling wbits times + let mut result: G1Projective = *accumulated_points.last().unwrap(); + for point in accumulated_points.into_iter().rev().skip(1) { + // Double the result 'wbits' times + for _ in 0..self.wbits { + result = result.double(); + } + // Add the accumulated point for this window + result += point; + } + + result + } + + // Returns all of the unreduced windows + fn partial_msm_part1(&self, scalars: &[Scalar]) -> Vec> { + let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect(); + let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1; + + let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows]; + + for window_idx in 0..number_of_windows { + for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() { + let sub_table = &self.table[scalar_idx]; + let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref()); + + if point_idx == 0 { + continue; + } + let sign = point_idx.is_positive(); + let point_idx = point_idx.unsigned_abs() as usize - 1; + let mut point = sub_table[point_idx]; + // let mut point = self.table[(scalar_idx * 1 << self.wbits) + point_idx]; + if !sign { + point = -point; + } + + windows_of_points[window_idx].push(point); + } + } + + windows_of_points + } + + fn partial_msm_part2(&self, accumulated_points: &[G1Projective]) -> G1Projective { + // Now accumulate the windows by doubling wbits times + let mut result: G1Projective = *accumulated_points.last().unwrap(); + for point in accumulated_points.into_iter().rev().skip(1) { + // Double the result 'wbits' times + for _ in 0..self.wbits { + result = result.double(); + } + // Add the accumulated point for this window + 
result += point; + } + + result + } + + fn partial_msm_part2_affine(&self, accumulated_points: &[G1Affine]) -> G1Projective { + // Now accumulate the windows by doubling wbits times + let mut result: G1Projective = G1Projective::from(*accumulated_points.last().unwrap()); + for point in accumulated_points.into_iter().rev().skip(1) { + // Double the result 'wbits' times + for _ in 0..self.wbits { + result = result.double(); + } + // Add the accumulated point for this window + result += point; + } + + result + } +} + +#[derive(Debug)] +pub struct FixedBaseMultiMSMPrecompBLST { + msms: Vec, +} + +impl FixedBaseMultiMSMPrecompBLST { + pub fn new(generator_sets: Vec>, wbits: usize) -> Self { + let msms: Vec<_> = generator_sets + .into_iter() + .map(|generators| FixedBaseMSMPrecompBLST::new(&generators, wbits)) + .collect(); + Self { msms } + } + + pub fn multi_msm(&self, scalar_sets: Vec>) -> Vec { + let num_results = scalar_sets.len(); + + let number_of_windows = Scalar::NUM_BITS as usize / self.msms[0].wbits + 1; + + let multiple_windows: Vec<_> = scalar_sets + .into_iter() + .zip(&self.msms) + .flat_map(|(scalars, msm)| msm.partial_msm_part1(&scalars)) + .collect(); + + let accumulated_points = multi_batch_addition_diff_stride(multiple_windows); + + let mut results = Vec::with_capacity(num_results); + + for (set_of_windows, msm) in accumulated_points + .chunks_exact(number_of_windows) + .into_iter() + .zip(&self.msms) + { + results.push(msm.partial_msm_part2(set_of_windows)); + } + + results + } +} + +#[test] +fn precomp_lookup_table() { + use group::Group; + let lookup_table = FixedBaseMSMPrecompBLST::precompute_points(7, G1Affine::generator()); + + for i in 1..lookup_table.len() { + let expected = G1Projective::generator() * Scalar::from((i + 1) as u64); + assert_eq!(lookup_table[i], expected.into(),) + } +} +use ff::Field; + +#[test] +fn msm_blst_precomp() { + let length = 64; + let generators: Vec<_> = (0..length) + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect(); + let scalars: Vec<_> = (0..length) + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + let res = crate::lincomb::g1_lincomb(&generators, &scalars) + .expect("number of generators and number of scalars is equal"); + + let fbm = FixedBaseMSMPrecompBLST::new(&generators, 7); + let result = fbm.msm(&scalars); + + assert_eq!(res, result); +} + +#[test] +fn bench_window_sizes_msm() { + let length = 64; + let generators: Vec<_> = (0..length) + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect(); + let scalars: Vec<_> = (0..length) + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + for i in 2..=14 { + let fbm = FixedBaseMSMPrecompBLST::new(&generators, i); + fbm.msm(&scalars); + } +} diff --git a/cryptography/bls12_381/src/fixed_base_msm_blst_all_windows.rs b/cryptography/bls12_381/src/fixed_base_msm_blst_all_windows.rs new file mode 100644 index 00000000..750dddad --- /dev/null +++ b/cryptography/bls12_381/src/fixed_base_msm_blst_all_windows.rs @@ -0,0 +1,166 @@ +use crate::{ + batch_add::{batch_addition, batch_addition_diff_stride}, + booth_encoding::get_booth_index, + g1_batch_normalize, G1Projective, Scalar, +}; +use rayon::prelude::*; + +use blstrs::G1Affine; +use ff::{Field, PrimeField}; +use group::Group; + +// Note: This is the same strategy that blst uses +#[derive(Debug)] +pub struct FixedBaseMSMPrecompAllWindow { + tables: Vec>, + window_size: usize, + num_windows: usize, +} + +impl FixedBaseMSMPrecompAllWindow { + pub fn new(points: 
&[G1Affine], window_size: usize) -> Self { + let num_windows = Scalar::NUM_BITS as usize / window_size + 1; + + let precomputed_points: Vec<_> = points + .iter() + .map(|point| Self::precompute_points(window_size, num_windows, *point)) + .collect(); + + Self { + tables: precomputed_points, + window_size, + num_windows, + } + } + + fn precompute_points( + window_size: usize, + number_of_windows: usize, + point: G1Affine, + ) -> Vec { + let window_size_scalar = Scalar::from(1 << window_size); + + use rayon::prelude::*; + + let all_tables: Vec<_> = (0..number_of_windows) + .into_par_iter() + .flat_map(|window_index| { + let window_scalar = window_size_scalar.pow(&[window_index as u64]); + let mut lookup_table = Vec::with_capacity(1 << (window_size - 1)); + let point = G1Projective::from(point) * window_scalar; + let mut current = point; + // Compute and store multiples + for _ in 0..(1 << (window_size - 1)) { + lookup_table.push(current); + current += point; + } + g1_batch_normalize(&lookup_table) + }) + .collect(); + + all_tables + } + + // Given a point, we precompute P,..., (2^{w-1}-1) * P + // fn precompute_points(wbits: usize, point: G1Affine) -> Vec { + // let mut lookup_table = Vec::with_capacity(1 << (wbits - 1)); + + // // Convert to projective for faster operations + // let mut current = G1Projective::from(point); + + // // Compute and store multiples + // for _ in 0..(1 << (wbits - 1)) { + // lookup_table.push(current); + // current += point; + // } + + // g1_batch_normalize(&lookup_table) + // } + + pub fn msm(&self, scalars: &[Scalar]) -> G1Projective { + let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect(); + + let mut points_to_add = Vec::new(); + + for window_idx in 0..self.num_windows { + for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() { + let sub_table = &self.tables[scalar_idx]; + let point_idx = + get_booth_index(window_idx, self.window_size, scalar_bytes.as_ref()); + + if point_idx == 0 { + continue; + } + let sign = point_idx.is_positive(); + let point_idx = point_idx.unsigned_abs() as usize - 1; + + // Scale the point index by the window index to figure out whether + // we need P, 2^wP, 2^{2w}P, etc + let scaled_point_index = window_idx * (1 << (self.window_size - 1)) + point_idx; + let mut point = sub_table[scaled_point_index]; + + if !sign { + point = -point; + } + + points_to_add.push(point); + } + } + + batch_addition(points_to_add).into() + } +} + +#[cfg(test)] +mod all_windows_tests { + use super::*; + use ff::Field; + use group::prime::PrimeCurveAffine; + + #[test] + fn precomp_lookup_table() { + use group::Group; + let lookup_table = + FixedBaseMSMPrecompAllWindow::precompute_points(7, 1, G1Affine::generator()); + + for i in 1..lookup_table.len() { + let expected = G1Projective::generator() * Scalar::from((i + 1) as u64); + assert_eq!(lookup_table[i], expected.into(),) + } + } + + #[test] + fn msm_blst_precomp() { + let length = 64; + let generators: Vec<_> = (0..length) + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect(); + let scalars: Vec<_> = (0..length) + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + let res = crate::lincomb::g1_lincomb(&generators, &scalars) + .expect("number of generators and number of scalars is equal"); + + let fbm = FixedBaseMSMPrecompAllWindow::new(&generators, 7); + let result = fbm.msm(&scalars); + + assert_eq!(res, result); + } + + #[test] + fn bench_window_sizes_msm() { + let length = 64; + let generators: Vec<_> = (0..length) + .map(|_| 
G1Projective::random(&mut rand::thread_rng()).into()) + .collect(); + let scalars: Vec<_> = (0..length) + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + for i in 2..=14 { + let fbm = FixedBaseMSMPrecompAllWindow::new(&generators, i); + fbm.msm(&scalars); + } + } +} diff --git a/cryptography/bls12_381/src/fixed_base_msm_pippenger.rs b/cryptography/bls12_381/src/fixed_base_msm_pippenger.rs index 8e59e653..992accec 100644 --- a/cryptography/bls12_381/src/fixed_base_msm_pippenger.rs +++ b/cryptography/bls12_381/src/fixed_base_msm_pippenger.rs @@ -1,9 +1,11 @@ use std::collections::HashSet; +use std::thread::current; use blstrs::G1Affine; use blstrs::G1Projective; use blstrs::Scalar; use ff::PrimeField; +use group::prime::PrimeCurveAffine; use group::Group; use crate::booth_encoding::get_booth_index; @@ -18,7 +20,7 @@ pub struct FixedBaseMSMPippenger { impl FixedBaseMSMPippenger { pub fn new(points: &[G1Affine]) -> FixedBaseMSMPippenger { // The +2 was empirically seen to give better results - let window_size = (f64::from(points.len() as u32)).ln().ceil() as usize + 2; + let window_size = 8; let number_of_windows = Scalar::NUM_BITS as usize / window_size + 1; let precomputed_points = precompute(window_size, number_of_windows, points); diff --git a/cryptography/bls12_381/src/lib.rs b/cryptography/bls12_381/src/lib.rs index 7d349584..5575386c 100644 --- a/cryptography/bls12_381/src/lib.rs +++ b/cryptography/bls12_381/src/lib.rs @@ -1,9 +1,16 @@ -mod batch_add; +pub mod batch_add; +pub mod batch_add_blst; pub mod batch_inversion; mod booth_encoding; pub mod fixed_base_msm; +pub mod fixed_base_msm_blst; +pub mod fixed_base_msm_blst_all_windows; pub mod fixed_base_msm_pippenger; +pub mod limlee; pub mod lincomb; +pub mod seokim; +pub mod simple_msm; +pub mod wnaf; // Re-export ff and group, so other crates do not need to directly import(and independently version) them pub use ff; diff --git a/cryptography/bls12_381/src/limlee.rs b/cryptography/bls12_381/src/limlee.rs new file mode 100644 index 00000000..aa60ba45 --- /dev/null +++ b/cryptography/bls12_381/src/limlee.rs @@ -0,0 +1,1110 @@ +use core::num; + +use blstrs::{Fp, G1Affine, G1Projective, Scalar}; +use ff::{Field, PrimeField}; +use group::{prime::PrimeCurveAffine, Group, WnafScalar}; + +use crate::{ + batch_add::{batch_addition, multi_batch_addition, multi_batch_addition_diff_stride}, + g1_batch_normalize, + wnaf::wnaf_form, +}; + +// Reference: http://mhutter.org/papers/Mohammed2012ImprovedFixedBase.pdf +// +// For now I will use the variables used in the paper, and then we can +// rename them to be more descriptive. +#[derive(Debug, Clone)] +pub struct LimLee { + l: u32, + // For a scalar with `l` bits, + // We choose a splitting parameter `h` such that + // the `l` bits of the scalar is split into `a = l/h` bits + // + h: u32, + // The scalars bits are grouped into `a` bit-groups of size `a` + a: u32, + // For the `bit-groups of size `a`, + // We choose a splitting parameter `v` such that + // the `a` bits are split into `b = a / v` bits + v: u32, + // + b: u32, + // + precomputed_points: Vec>>, +} + +impl LimLee { + pub fn new(h: u32, v: u32, points: &[G1Affine]) -> LimLee { + // Compute `a`. + + // TODO: Add one so that we view all scalars as 256 bit numbers. 
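// Editor's sketch (not part of the patch): the Lim-Lee split below works on an l-bit
// scalar (the 255-bit Scalar viewed as 256 bits and padded to a multiple of h), cut into
// h rows of a = l/h bits, each row cut into v columns of b = a/v bits. A tiny helper that
// makes the parameter relationship explicit, mirroring the asserts in `LimLee::new`:
fn lim_lee_params(l: u32, h: u32, v: u32) -> (u32, u32) {
    let a = l.div_ceil(h); // bits per row
    assert!(v <= a && a % v == 0, "v must divide a (a={a}, v={v})");
    let b = a / v; // bits per column chunk
    (a, b)
}
// Example: l = 256 and h = 8 give a = 32; v = 2 then gives b = 16, which is the
// `(8, 2)` configuration used in `smoke_test_lim_lee_msm` below.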
+ // We can modify it to view everything as 255 bits with a tiny bit of refactoring + // when we pad the decomposed scalar + let l = Self::compute_padded_scalar_bits(Scalar::NUM_BITS + 1, h); + // First of all check that h < l + assert!(h < l); + let a = l.div_ceil(h); + + assert!(v <= a); + assert!( + a % v == 0, + "v must be a factor of a, so that b can be equally sized v={v}, a={a}", + ); + // Compute `b` + let b = a.div_ceil(v); + + let mut ll = LimLee { + h, + a, + v, + b, + precomputed_points: Vec::new(), + l, + }; + use rayon::prelude::*; + let precomputed = points + .into_par_iter() + .map(|point| ll.precompute_point(*point)) + .collect(); + ll.precomputed_points = precomputed; + + ll + } + + // we want to compute scalar_size / divider but pad by zeroes + // if the scalar_size does not divide 'divider' + // + // This method returns the padded size of the scalar + fn compute_padded_scalar_bits(scalar_size: u32, divider: u32) -> u32 { + scalar_size.div_ceil(divider) * divider + } + + // This corresponds to the naive sum in 3.1 where there is no pre-computation + // and P is the generator + pub fn scalar_mul_naive(&self, scalar: Scalar) -> G1Projective { + dbg!(&self); + let mut scalar_bits = scalar_to_bits(scalar).to_vec(); + + // Pad the scalar, if the value of `l` necesitates it + scalar_bits.extend(vec![0u8; self.l as usize - scalar_bits.len()]); // 256 here because we convert to bytes and then bits + + // Group the scalar bits into `a` chunks + assert!(scalar_bits.len() as u32 % self.b == 0); + let mut b_chunks: Vec<_> = scalar_bits.chunks_exact(self.b as usize).collect(); + let scalar_bits: Vec<_> = b_chunks.into_iter().map(|b| bits_to_byte(b)).collect(); + + // For the columns + let mut result = G1Projective::identity(); + + for j in 0..self.v { + for i in 0..self.h { + // We use a flat array, but the algorithm + // is based off of a matrix, so compute the flattened index + let index = i * self.v + j; + let digit = scalar_bits[index as usize]; + + let exponent = j * self.b + i * self.a; + let mut tmp = G1Projective::generator(); + for _ in 0..exponent { + tmp = tmp.double(); + } + result += tmp * Scalar::from(digit as u64); + } + } + + result + } + + // This corresponds to equation 3 on page 347 + pub fn scalar_mul_eq3(&self, scalar: Scalar) -> G1Projective { + let mut scalar_bits = scalar_to_bits(scalar).to_vec(); + + // Pad the scalar, if the value of `l` necessitates it + scalar_bits.extend(vec![0u8; self.l as usize - 256]); // 256 here because we convert to bytes and then bits + + // Group the scalar bits into `a` chunks + assert!(scalar_bits.len() as u32 % self.b == 0); + let mut b_chunks: Vec<_> = scalar_bits.chunks_exact(self.b as usize).collect(); + let scalar_bits: Vec<_> = b_chunks.into_iter().map(|b| bits_to_byte(b)).collect(); + + // Precomputations + let mut precomputations = Vec::new(); + precomputations.push(G1Projective::generator()); + for i in 0..self.h { + let two_pow_a = Scalar::from(2u64).pow(&[self.a as u64]); + precomputations.push(precomputations.last().unwrap() * two_pow_a); + } + + // For the columns + let mut result = G1Projective::identity(); + + for j in 0..self.v { + for i in 0..self.h { + // We use a flat array, but the algorithm + // is based off of a matrix, so compute the flattened index + let index = i * self.v + j; + let digit = scalar_bits[index as usize]; + + let exponent = j * self.b; + let mut tmp = precomputations[i as usize]; + for _ in 0..exponent { + tmp = tmp.double(); + } + result += tmp * Scalar::from(digit as u64); + } + } + + 
result + } + + // This corresponds to eq4 on page 347 + pub fn scalar_mul_eq4(&self, scalar: Scalar) -> G1Projective { + let mut scalar_bits = scalar_to_bits(scalar).to_vec(); + + // Pad the scalar, if the value of `l` necessitates it + scalar_bits.extend(vec![0u8; self.l as usize - 256]); // 256 here because we convert to bytes and then bits + + // Precomputations + let mut precomputations = Vec::new(); + precomputations.push(G1Projective::generator()); + for i in 0..self.h { + let two_pow_a = Scalar::from(2u64).pow(&[self.a as u64]); + precomputations.push(precomputations.last().unwrap() * two_pow_a); + } + + let mut result = G1Projective::identity(); + // For the columns + + for t in 0..self.b { + let mut double_inner_sum = G1Projective::identity(); + for j in 0..self.v { + for i in 0..self.h { + // We use a flat array, but the algorithm + // is based off of a matrix, so compute the flattened index + let index = i * self.v * self.b + j * self.b + t; + let digit = scalar_bits[index as usize]; + + let exponent = j * self.b; + let mut tmp = precomputations[i as usize]; + for _ in 0..exponent { + tmp = tmp.double(); + } + double_inner_sum += tmp * Scalar::from(digit as u64); + } + } + + for _ in 0..t { + double_inner_sum = double_inner_sum.double() + } + result += double_inner_sum; + } + + result + } + + // This corresponds to eq5 on page 347 + pub fn scalar_mul_eq5(&self, scalar: Scalar) -> G1Projective { + let mut scalar_bits = scalar_to_bits(scalar).to_vec(); + + // Pad the scalar, if the value of `l` necessitates it + scalar_bits.extend(vec![0u8; self.l as usize - 256]); // 256 here because we convert to bytes and then bits + + // Precomputations + let mut precomputations = Vec::new(); + precomputations.push(G1Projective::generator()); + for i in 0..self.h { + let two_pow_a = Scalar::from(2u64).pow(&[self.a as u64]); + precomputations.push(precomputations.last().unwrap() * two_pow_a); + } + + let mut g_s = + vec![vec![G1Projective::identity(); (1 << self.h) as usize]; (self.v as usize)]; + + // Initialize the j==0 case + // Compute G[0][s] for all s + for s in 1..(1 << self.h) { + let mut g0s = G1Projective::identity(); + for i in 0..self.h { + if (s & (1 << i)) != 0 { + g0s += precomputations[i as usize]; + } + } + g_s[0][s] = g0s; + } + + // Compute G[j][s] for j > 0 + let two_pow_b = Scalar::from(2u64).pow(&[self.b as u64]); + for j in 1..self.v as usize { + for s in 1..(1 << self.h) as usize { + g_s[j][s] = g_s[j - 1][s] * two_pow_b; + } + } + + let g_s: Vec<_> = g_s + .into_iter() + .map(|g_s_i| g1_batch_normalize(&g_s_i)) + .collect(); + + let mut total_len = 0; + for g in &g_s { + total_len += g.len() + } + dbg!(total_len); + + let mut result = G1Projective::identity(); + for t in 0..self.b { + let mut double_inner_sum = G1Projective::identity(); + for j in 0..self.v { + let i_jt = self.compute_i_jt(&scalar_bits, j, t); + if i_jt != 0 { + double_inner_sum += g_s[j as usize][i_jt]; + } + } + + for _ in 0..t { + double_inner_sum = double_inner_sum.double() + } + result += double_inner_sum; + } + result + } + + pub fn msm(&self, scalars: &[Scalar]) -> G1Projective { + // Convert scalars to bits + // let now = std::time::Instant::now(); + let scalars_bits: Vec<_> = scalars + .into_iter() + .map(|scalar| { + let mut scalar_bits = scalar_to_bits(*scalar).to_vec(); + scalar_bits.extend(vec![0u8; self.l as usize - scalar_bits.len()]); + scalar_bits + }) + .collect(); + // dbg!("scalar conversion", now.elapsed().as_micros()); + let mut window: Vec> = vec![vec![]; self.b as usize]; + + // 
let now = std::time::Instant::now(); + for (scalar_index, scalar_bits) in scalars_bits.iter().enumerate() { + for t in (0..self.b) { + for j in 0..self.v { + let i_jt = self.compute_i_jt(&scalar_bits, j, t); + if i_jt != 0 { + window[t as usize] + .push(self.precomputed_points[scalar_index][j as usize][i_jt]); + } + } + } + } + + let mut result = G1Projective::identity(); + let summed_windows = multi_batch_addition_diff_stride(window); + + for (window) in summed_windows.into_iter().rev() { + result = result.double(); + result += window; + } + + // dbg!(now.elapsed().as_micros()); + result + } + + fn num_precomputed_points(&self) -> usize { + let mut total = 0; + for set_of_points in &self.precomputed_points { + for row in set_of_points { + total += row.len(); + } + } + total + } + + fn precompute_point(&self, point: G1Affine) -> Vec> { + let point = G1Projective::from(point); + // Precomputations + let mut precomputations = Vec::new(); + precomputations.push(point); + for i in 0..self.h { + let two_pow_a = Scalar::from(2u64).pow(&[self.a as u64]); + precomputations.push(precomputations.last().unwrap() * two_pow_a); + } + + let mut g_s = + vec![vec![G1Projective::identity(); (1 << self.h) as usize]; (self.v as usize)]; + + // Initialize the j==0 case + // Compute G[0][s] for all s + for s in 1..(1 << self.h) { + let mut g0s = G1Projective::identity(); + for i in 0..self.h { + if (s & (1 << i)) != 0 { + g0s += precomputations[i as usize]; + } + } + g_s[0][s] = g0s; + } + + // Compute G[j][s] for j > 0 + let two_pow_b = Scalar::from(2u64).pow(&[self.b as u64]); + for j in 1..self.v as usize { + for s in 1..(1 << self.h) as usize { + g_s[j][s] = g_s[j - 1][s] * two_pow_b; + } + } + + let g_s: Vec<_> = g_s + .into_iter() + .map(|g_s_i| g1_batch_normalize(&g_s_i)) + .collect(); + g_s + } + + fn compute_i_jt(&self, k: &[u8], j: u32, t: u32) -> usize { + let mut i_jt = 0; + for i in 0..self.h { + let bit_index = (i * self.v * self.b + j * self.b + t) as usize; + if bit_index < k.len() && (k[bit_index] == 1) { + i_jt |= 1 << i; + } + } + i_jt as usize + } +} + +type PrecomputedPoints = Vec>; + +#[derive(Debug)] +pub struct TsaurChou { + // These are not the same as LimLee + // + // + omega: usize, + v: usize, + a: usize, + b: usize, + num_bits: usize, + + precomputed_points: Vec, +} + +impl TsaurChou { + pub fn new(omega: usize, v: usize, points: &[G1Affine]) -> TsaurChou { + let num_bits = Scalar::NUM_BITS + 1; + + // This is the padded number of bits needed to make sure division + // by omega is exact. 
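// Editor's sketch (hypothetical helper, not part of the patch): TsaurChou leans on the
// ω-NAF property that every non-zero digit is odd, lies in (-2^{ω-1}, 2^{ω-1}), and is
// followed by at least ω-1 zeros, so each ω-digit window K_{jb+t} holds at most one
// non-zero digit. A small standalone recoder for a u64 (the crate's `wnaf_form` is
// assumed to behave the same way on scalar bytes):
fn wnaf_recode_u64(k: u64, w: u32) -> Vec<i64> {
    assert!((2..=32).contains(&w));
    let modulus: i128 = 1 << w;
    let mut k = k as i128;
    let mut digits = Vec::new();
    while k > 0 {
        if k & 1 == 1 {
            // symmetric residue of k mod 2^w, always odd
            let mut d = k & (modulus - 1);
            if d >= modulus / 2 {
                d -= modulus;
            }
            digits.push(d as i64);
            k -= d; // clears the low w bits (adds |d| when d < 0)
        } else {
            digits.push(0);
        }
        k >>= 1;
    }
    digits
}

#[test]
fn wnaf_recode_roundtrip() {
    // Same scalar as `wnaf_smoke_test` further down in this file.
    let k: u64 = 1_065_142_573_068;
    let digits = wnaf_recode_u64(k, 5);
    // Recombination: sum(digits[i] * 2^i) == k.
    let recombined: i128 = digits.iter().enumerate().map(|(i, &d)| (d as i128) << i).sum();
    assert_eq!(recombined, k as i128);
    // At most one non-zero digit in any window of 5 consecutive positions.
    for window in digits.windows(5) {
        assert!(window.iter().filter(|d| **d != 0).count() <= 1);
    }
}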
+ let num_bits = Self::calculate_padded_size(num_bits as usize, omega); + + let a = num_bits / omega; + + // assert!(a % v == 0, "a={} v={}", a, v); + let b = a.div_ceil(v); + + let mut precomputed_points = Vec::new(); + for point in points { + precomputed_points.push(Self::precompute_point(*point, omega, b, v)) + } + + Self { + omega, + v, + a, + b, + num_bits, + precomputed_points, + } + } + + fn calculate_padded_size(l: usize, w: usize) -> usize { + let a = (l + w - 1) / w; // This is ⌈l/ω⌉ + let padded_size = a * w; + // TODO: if statement not needed, if we do div_ceil + let padding_zeros = if l % w == 0 { 0 } else { padded_size - l }; + padding_zeros + l + } + + fn num_precomputed_points(&self) -> usize { + let mut result = 0; + for points in &self.precomputed_points { + for p in points.iter() { + result += p.len() + } + } + result + } + + // On page350, this is the first double summation + pub fn mul_naive(&self, scalar: &Scalar) -> G1Projective { + // Convert scalar to wnaf + // let mut wnaf_digits = vec![]; + // wnaf_form(&mut wnaf_digits, scalar.to_repr(), self.omega); + let mut wnaf_digits = scalar_to_bits(*scalar).to_vec(); + wnaf_digits.extend(vec![0u8; self.num_bits - wnaf_digits.len()]); + let point = G1Projective::generator(); + let mut result = G1Projective::identity(); + + // 1. Compute the precomputations + + // 2. iterate `w` bits and compute the scalar_mul + for j in 0..self.v { + for t in 0..self.b { + // Choose K_jb+t + let exponent = t * self.omega + j * self.b * self.omega; + let two_pow_exponent = Scalar::from(2u64).pow(&[exponent as u64]); + + // Index K_jb+t + let start_index = (j * self.b + t) * self.omega; + let end_index = start_index + self.omega; + let k_jbt = &wnaf_digits[start_index..end_index.min(wnaf_digits.len())]; + // Convert K_jb+t from NAF to scalar + let mut digit = Scalar::ZERO; + for (i, &bit) in k_jbt.iter().enumerate() { + if bit > 0 { + digit += Scalar::from(bit as u64) * Scalar::from(2u64).pow(&[i as u64]); + } else if bit < 0 { + digit += -Scalar::from(bit as u64) * Scalar::from(2u64).pow(&[i as u64]); + } + } + + result += point * digit * two_pow_exponent; + } + } + + result + } + + // On page 350, this is the second summation. next to the first one + // under the matrix. Where we pull out a 2^tw + pub fn mul_naive_better(&self, scalar: &Scalar) -> G1Projective { + // Convert scalar to wnaf + // let mut wnaf_digits = vec![]; + // wnaf_form(&mut wnaf_digits, scalar.to_repr(), self.omega); + let mut wnaf_digits = scalar_to_bits(*scalar).to_vec(); + wnaf_digits.extend(vec![0u8; self.num_bits - wnaf_digits.len()]); + let point = G1Projective::generator(); + let mut result = G1Projective::identity(); + // TODO: I think we need to pad here after wnaf + + // 1. Compute the precomputations + + // 2. 
iterate `w` bits and compute the scalar_mul + for t in 0..self.b { + let two_pow_tw = Scalar::from(2u64).pow(&[(t * self.omega) as u64]); + let mut inner_sum = G1Projective::identity(); + for j in 0..self.v { + // Choose K_jb+t + let exponent = j * self.b * self.omega; + let two_pow_exponent = Scalar::from(2u64).pow(&[exponent as u64]); + + // Index K_jb+t + let start_index = (j * self.b + t) * self.omega; + let end_index = start_index + self.omega; + let k_jbt = &wnaf_digits[start_index..end_index]; + // Convert K_jb+t from NAF to scalar + let mut digit = Scalar::ZERO; + for (i, &bit) in k_jbt.iter().enumerate() { + if bit > 0 { + digit += Scalar::from(bit as u64) * Scalar::from(2u64).pow(&[i as u64]); + } else if bit < 0 { + digit += -Scalar::from(bit as u64) * Scalar::from(2u64).pow(&[i as u64]); + } + } + + inner_sum += point * digit * two_pow_exponent; + } + + result += inner_sum * two_pow_tw; + } + + result + } + + // This is just the same method but it uses wnaf instead of bits + pub fn mul_naive_better_wnaf(&self, scalar: &Scalar) -> G1Projective { + // Convert scalar to wnaf + let mut wnaf_digits = vec![]; + let mut scalar_bytes = scalar.to_bytes_le().to_vec(); + scalar_bytes.extend(vec![0u8; self.num_bits / 8 + 1 - scalar_bytes.len()]); // TODO: double check for rounding error + wnaf_form(&mut wnaf_digits, scalar_bytes, self.omega); + // let wnaf_digits = scalar_to_bits(*scalar); + let point = G1Projective::generator(); + let mut result = G1Projective::identity(); + // TODO: I think we need to pad here after wnaf + + // 1. Compute the precomputations + + // 2. iterate `w` bits and compute the scalar_mul + for t in 0..self.b { + let two_pow_tw = Scalar::from(2u64).pow(&[(t * self.omega) as u64]); + let mut inner_sum = G1Projective::identity(); + for j in 0..self.v { + // Choose K_jb+t + let exponent = j * self.b * self.omega; + let two_pow_exponent = Scalar::from(2u64).pow(&[exponent as u64]); + + // Index K_jb+t + let start_index = (j * self.b + t) * self.omega; + let end_index = start_index + self.omega; + let k_jbt = &wnaf_digits[start_index..end_index]; + // Convert K_jb+t from NAF to scalar + let mut digit = Scalar::ZERO; + for (i, &bit) in k_jbt.iter().enumerate() { + if bit > 0 { + digit += + Scalar::from(bit.abs() as u64) * Scalar::from(2u64).pow(&[(i) as u64]); + } else if bit < 0 { + digit += + -Scalar::from(bit.abs() as u64) * Scalar::from(2u64).pow(&[(i) as u64]); + } + } + inner_sum += point * digit * two_pow_exponent; + } + + result += inner_sum * two_pow_tw; + } + + result + } + + pub fn mul_naive_better_wnaf_precomputations(&self, scalar: &Scalar) -> G1Projective { + // Convert scalar to wnaf + let mut wnaf_digits = vec![]; + let mut scalar_bytes = scalar.to_bytes_le().to_vec(); + scalar_bytes.extend(vec![0u8; self.num_bits / 8 + 1 - scalar_bytes.len()]); // TODO: double check for rounding error + wnaf_form(&mut wnaf_digits, scalar_bytes, self.omega); + // let wnaf_digits = scalar_to_bits(*scalar); + let point = G1Affine::generator(); + let mut result = G1Projective::identity(); + // TODO: I think we need to pad here after wnaf + + // 1. Compute the precomputations + // Precomputation + let precomp = Self::precompute_point(point, self.omega, self.b, self.v); + + let now = std::time::Instant::now(); + + let mut windows = vec![vec![]; self.b]; + // 2. 
iterate `w` bits and compute the scalar_mul + for t in 0..self.b { + for j in 0..self.v { + let start_index = (j * self.b + t) * self.omega; + let end_index = start_index + self.omega; + let k_jbt = &wnaf_digits[start_index..end_index]; + + let mut s_exponent = 0; + let mut digit = 0; + + for (i, &bit) in k_jbt.iter().enumerate() { + if bit != 0 { + // Use bit shifting for 2^i + s_exponent = i; + digit = bit; + break; // In ω-NAF, only one non-zero digit per window + } + } + + if digit != 0 { + let abs_digit = digit.unsigned_abs() as u64; + let mut chosen_point = precomp[j] + [Self::sd_to_index(s_exponent, abs_digit as usize, self.omega as u32)]; + if digit < 0 { + chosen_point = -chosen_point; + } + windows[t].push(chosen_point); + } + } + } + + // Combine each sum in each window + let windows: Vec<_> = windows + .into_iter() + .map(|window| batch_addition(window)) + .collect(); + + // Combine windows + // for (t, window) in windows.into_iter().enumerate() { + // if t * self.omega == 0 { + // result += window + // } else if t * self.omega == 1 { + // result += G1Projective::from(window).double(); + // } else { + // // let inner_sum: G1Affine = window.into(); + + // let inner_sum = direct_doubling(t * self.omega, window); + + // result += inner_sum; + // } + // } + + for window in windows.into_iter().rev() { + for _ in 0..self.omega { + result = result.double() + } + + result += window; + } + + dbg!(now.elapsed().as_micros()); + + result + } + + // This is closer to the cleaned up version that does not have + // the precomps being done internally. + // + // These are computed in the constructor + pub fn mul_naive_better_wnaf_precomputations_final_msm( + &self, + scalars: &[Scalar], + ) -> G1Projective { + fn scalar_to_wnaf(scalar: Scalar, num_bits: usize, omega: usize) -> Vec { + let mut wnaf_digits = vec![]; + let mut scalar_bytes = scalar.to_bytes_le(); + // scalar_bytes.extend(vec![0u8; num_bits / 8 + 1 - scalar_bytes.len()]); // TODO: double check for rounding error + wnaf_form(&mut wnaf_digits, scalar_bytes, omega); + wnaf_digits + } + // let now = std::time::Instant::now(); + let scalars_wnaf_digits: Vec<_> = scalars + .into_iter() + .map(|scalar| scalar_to_wnaf(*scalar, self.num_bits, self.omega)) + .collect(); + // dbg!(now.elapsed().as_micros()); + // let wnaf_digits = scalar_to_wnaf(scalars[0], self.num_bits, self.omega); + // Convert scalar to wnaf + // let wnaf_digits = scalar_to_bits(*scalar); + let mut result = G1Projective::identity(); + + // let now = std::time::Instant::now(); + + let mut windows = vec![vec![]; self.b]; + // 2. 
iterate `w` bits and compute the scalar_mul + for t in 0..self.b { + for j in 0..self.v { + for (scalar_index, wnaf_digits) in scalars_wnaf_digits.iter().enumerate() { + let start_index = (j * self.b + t) * self.omega; + let end_index = start_index + self.omega; + if start_index > wnaf_digits.len() { + continue; + } + + let k_jbt = &wnaf_digits[start_index..end_index.min(wnaf_digits.len())]; + + let mut s_exponent = 0; + let mut digit = 0; + + for (i, &bit) in k_jbt.iter().enumerate() { + if bit != 0 { + // Use bit shifting for 2^i + s_exponent = i; + digit = bit; + break; // In ω-NAF, only one non-zero digit per window + } + } + + if digit != 0 { + let abs_digit = digit.unsigned_abs() as u64; + let mut chosen_point = self.precomputed_points[scalar_index][j] + [Self::sd_to_index(s_exponent, abs_digit as usize, self.omega as u32)]; + if digit < 0 { + chosen_point = -chosen_point; + } + windows[t].push(chosen_point); + } + } + } + } + + // Combine each sum in each window + // let windows: Vec<_> = windows + // .into_iter() + // .map(|window| batch_addition(window)) + // .collect(); + // let now = std::time::Instant::now(); + let windows = multi_batch_addition_diff_stride(windows); + // dbg!(now.elapsed().as_micros()); + // Combine windows + // for (t, window) in windows.into_iter().enumerate() { + // if t * self.omega == 0 { + // result += window + // } else if t * self.omega == 1 { + // result += G1Projective::from(window).double(); + // } else { + // // let inner_sum: G1Affine = window.into(); + + // let inner_sum = direct_doubling(t * self.omega, window); + + // result += inner_sum; + // } + // } + for window in windows.into_iter().rev() { + for _ in 0..self.omega { + result = result.double() + } + + result += window; + } + + // dbg!(now.elapsed().as_micros()); + + result + } + + fn sd_to_index(s_exp: usize, d: usize, w: u32) -> usize { + s_exp * (1 << (w - 2)) + (d - 1) / 2 + } + + fn precompute_point_old( + point: G1Affine, + omega: usize, + b: usize, + v: usize, + ) -> Vec> { + let point = G1Projective::from(point); + + let inner_size = omega * (1 << (omega - 2)); + let mut precomp = vec![vec![G1Projective::identity(); inner_size]; v]; + + for s in 0..omega { + for d in (1..1 << (omega - 1)).step_by(2) { + let index = s * (1 << (omega - 2)) + (d - 1) / 2; + let sd = (1 << s) * d; + precomp[0][index] = point * (&Scalar::from(sd as u64)); + } + } + + for j in 1..v { + let factor = Scalar::from(2u64).pow(&[(j * omega * b) as u64]); + for index in 0..inner_size { + precomp[j][index] = precomp[0][index] * (&factor); + } + } + + let precomp: Vec<_> = precomp + .into_iter() + .map(|points| g1_batch_normalize(&points)) + .collect(); + + precomp + } + + fn precompute_point(point: G1Affine, omega: usize, b: usize, v: usize) -> Vec> { + // d in the paper is just odd multiples + fn precompute_odd_multiples(base: G1Affine, w: usize) -> Vec { + let base = G1Projective::from(base); + let num_points = (1 << (w - 1)) / 2; // (2^(w-1)) / 2 points to compute + let mut results = vec![G1Projective::identity(); num_points]; + + // Compute 2P + let double_base = base.double(); + + // 1P is just the base point + results[0] = base; + + // Compute odd multiples: 3P, 5P, ..., (2^(w-1) - 1)P + for i in 1..num_points { + results[i] = results[i - 1] + double_base; + } + + results + } + + // let inner_size = omega * (1 << (omega - 2)); + let mut precomp = Vec::new(); + + let d_vec = precompute_odd_multiples(point, omega); + use rayon::prelude::*; + // Compute G_0 + let mut inner = Vec::new(); + 
inner.push(d_vec.clone()); + for s_exp in 1..omega { + let doubled = inner + .last() + .unwrap() + .par_iter() + .map(|p| p.double()) + .collect(); + inner.push(doubled) + } + precomp.push(inner.into_iter().flatten().collect::>()); + + // Now scale those G_j + for j in 1..v { + let mut scaled_inner: Vec<_> = precomp + .last() + .unwrap() + .par_iter() + .map(|inner| { + let mut res = *inner; + for _ in 0..omega * b { + res = res.double(); + } + res + }) + .collect(); + + precomp.push(scaled_inner.into_iter().collect::>()) + } + + let precomp: Vec<_> = precomp + .into_iter() + .map(|points| g1_batch_normalize(&points)) + .collect(); + + precomp + } +} + +fn direct_doubling(r: usize, point: G1Affine) -> G1Affine { + if point.is_identity().into() { + return G1Affine::identity(); + } + + // The below algorithm assumes r > 0 + // We could simply disallow it and panic + // I chose to return the the point since 2^0 * P = P + if r == 0 { + return point; + } + + // This is just a optimization, the algorithm, does + // allow this. + if r == 1 { + return G1Projective::from(point).double().into(); + } + + let mut previous_a_i = point.x(); + let mut previous_b_i = point.x().square().mul3(); + let mut previous_c_i = -point.y(); + let mut c_prod = previous_c_i; + + let mut current_a_i = Fp::ZERO; + let mut current_b_i = Fp::ZERO; + let mut current_c_i = Fp::ZERO; + + for i in 1..r { + current_a_i = previous_b_i.square() - previous_a_i.mul8() * previous_c_i.square(); + current_b_i = current_a_i.square().mul3(); + current_c_i = -previous_c_i.square().square().mul8() + - previous_b_i * (current_a_i - previous_a_i * previous_c_i.square() * Fp::from(4u64)); + c_prod *= current_c_i; + + previous_a_i = current_a_i; + previous_b_i = current_b_i; + previous_c_i = current_c_i; + } + + let a_r = current_a_i; + let b_r = current_b_i; + let c_r = current_c_i; + + // TODO: We square the same values etc below multiple times + // TODO: we could optimize and remove these, see for example c_r.square + + let d_r = a_r.mul3() * Fp::from(4u64) * c_r.square() - b_r.square(); + + let mut denom_prod = c_prod; + let denom = Fp::from(2u64).pow(&[r as u64]) * denom_prod; + let denom = denom.invert().unwrap(); + + let denom_sq = denom.square(); + let denom_cu = denom_sq * denom; + + // Compute x_2r + let numerator = b_r.square() - c_r.square().mul8() * a_r; + let x2r = numerator * denom_sq; + + // Compute y_2r + let numerator = c_r.square().square().mul8() - b_r * d_r; + let y2r = numerator * denom_cu; + + G1Affine::from_raw_unchecked(x2r, y2r, false) +} + +#[test] +fn direct_double() { + let point = G1Affine::generator(); + + for r in 2..10 { + // let r = 2; + let expected = (point * Scalar::from(2u64).pow(&[r as u64])).into(); + + let got = direct_doubling(r, point); + + assert_eq!(got, expected); + } +} + +fn random_points(num_points: usize) -> Vec { + (0..num_points) + .into_iter() + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect() +} + +#[test] +fn tsaur_chau() { + let ts = TsaurChou::new(5, 26, &[G1Affine::generator()]); + let scalar = -Scalar::from(1u64); + + let expected = G1Projective::generator() * scalar; + + let result = ts.mul_naive(&scalar); + assert!(result == expected); + let result = ts.mul_naive_better(&scalar); + assert!(result == expected); + + let result = ts.mul_naive_better_wnaf(&scalar); + assert!(result == expected); + + let result = ts.mul_naive_better_wnaf_precomputations(&scalar); + assert!(result == expected); + + let result = 
ts.mul_naive_better_wnaf_precomputations_final_msm(&[scalar]); + assert!(result == expected); +} + +#[test] +fn tsaur_chau_msm() { + let num_points = 64; + let points = random_points(num_points); + let ts = TsaurChou::new(5, 7, &points); // (5,7), (5,4), (4,12), (6,3), (8,2), (8,4), (8,1) + dbg!(ts.num_precomputed_points()); + + let scalars: Vec<_> = (0..num_points) + .into_iter() + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + let mut expected = G1Projective::identity(); + for (scalar, point) in scalars.iter().zip(points.iter()) { + expected += G1Projective::from(*point) * scalar + } + let now = std::time::Instant::now(); + let result = ts.mul_naive_better_wnaf_precomputations_final_msm(&scalars); + dbg!(now.elapsed().as_micros()); + assert!(result == expected); +} + +#[test] +fn wnaf_smoke_test() { + let s = Scalar::from(1065142573068u64); + let mut wnaf = vec![]; + // let mut wnaf_digits = vec![]; + let mut scalar_bytes = s.to_bytes_le().to_vec(); + scalar_bytes.extend(vec![0u8; 258 / 8 + 1 - scalar_bytes.len()]); + // wnaf_form(&mut wnaf_digits, scalar_bytes, self.omega); + wnaf_form(&mut wnaf, scalar_bytes, 3); + + dbg!(wnaf.chunks_exact(3).collect::>()); + + let mut result = Scalar::ZERO; + for (i, digit) in wnaf.into_iter().enumerate() { + if digit > 0 { + result += Scalar::from(digit.abs() as u64) * Scalar::from(2u64).pow(&[(i) as u64]); + } else if digit < 0 { + result += -Scalar::from(digit.abs() as u64) * Scalar::from(2u64).pow(&[(i) as u64]); + } + } + assert_eq!(result, s); +} + +#[test] +fn smoke_test_generator_scalar_mul() { + let ll = LimLee::new(8, 8, &[]); + let scalar = -Scalar::from(2u64); + + let expected = G1Projective::generator() * scalar; + + let result = ll.scalar_mul_naive(scalar); + assert!(result == expected); + + let got = ll.scalar_mul_eq3(scalar); + assert_eq!(got, result); + + let got = ll.scalar_mul_eq4(scalar); + assert_eq!(got, result); + + let got = ll.scalar_mul_eq5(scalar); + assert_eq!(got, result) +} + +#[test] +fn smoke_test_lim_lee_msm() { + let num_points = 1; + let points = random_points(num_points); + let ll = LimLee::new(8, 2, &points); // (8,2), (4,16), (5,4) + + let scalars: Vec<_> = (0..num_points) + .into_iter() + .map(|i| Scalar::from(i as u64)) + // .map(|i| Scalar::random(&mut rand::thread_rng())) + .collect(); + + let mut expected = G1Projective::identity(); + for (scalar, point) in scalars.iter().zip(points.iter()) { + expected += G1Projective::from(*point) * scalar + } + let now = std::time::Instant::now(); + let got = ll.msm(&scalars); + dbg!(now.elapsed().as_micros()); + dbg!(ll.num_precomputed_points()); + assert_eq!(got, expected); +} + +pub fn scalar_to_bits(s: Scalar) -> [u8; 256] { + let scalar_bytes = s.to_bytes_le(); + bytes_to_bits(scalar_bytes) +} +fn bytes_to_bits(bytes: [u8; 32]) -> [u8; 256] { + let mut bit_vector = Vec::with_capacity(256); + for byte in bytes { + for i in 0..8 { + bit_vector.push(((byte >> i) & 0x01) as u8) + } + } + bit_vector.try_into().unwrap() +} + +fn bits_to_byte(bits: &[u8]) -> u8 { + assert!( + bits.len() <= 8, + "currently we are returning a u8, so can only do 8 bits." 
+ ); + bits.iter() + .rev() + .fold(0, |acc, &bit| (acc << 1) | (bit & 1)) +} + +#[test] +fn compute_padded_scalar() { + struct TestCase { + scalar_size: u32, + divider: u32, + expected: u32, + } + + let cases = vec![ + // TODO: remove this and generalize + TestCase { + scalar_size: 255, + divider: 4, + expected: 256, + }, + TestCase { + scalar_size: 256, + divider: 4, + expected: 256, + }, + TestCase { + scalar_size: 100, + divider: 3, + expected: 102, + }, + ]; + + for case in cases { + let got = LimLee::compute_padded_scalar_bits(case.scalar_size, case.divider); + assert_eq!(got, case.expected) + } +} diff --git a/cryptography/bls12_381/src/seokim.rs b/cryptography/bls12_381/src/seokim.rs new file mode 100644 index 00000000..c55d2eaf --- /dev/null +++ b/cryptography/bls12_381/src/seokim.rs @@ -0,0 +1,568 @@ +// Implements https://www.mdpi.com/1424-8220/13/7/9483 + +use crate::batch_add::multi_batch_addition_diff_stride; +use crate::g1_batch_normalize; +use crate::limlee::scalar_to_bits; +use crate::wnaf::wnaf_form; +use blstrs::G1Affine; +use blstrs::{G1Projective, Scalar}; +use ff::Field; +use ff::PrimeField; +use group::prime::PrimeCurveAffine; +use group::Group; +use rayon::prelude::*; + +pub struct SeoKim { + w: usize, + a: usize, + l: usize, + z: usize, + + precomputed_points: Vec, +} + +// Precomputations needed for a single point +type PrecomputationsForPoint = Vec>; + +impl SeoKim { + pub fn new(omega: usize, points: &[G1Affine]) -> Self { + let num_bits = Self::calculate_padded_size((Scalar::NUM_BITS + 1) as usize, omega * omega); + let a = num_bits.div_ceil(omega); + + let z = a.div_ceil(omega); + + // let mut precomputed_points: Vec = Vec::new(); + // for point in points { + // precomputed_points.push(Self::precompute_point(*point, omega, z)); + // } + + let mut precomputed_points: Vec = points + .into_par_iter() + .map(|point| Self::precompute_point(*point, omega, z)) + .collect(); + + Self { + w: omega, + a: a as usize, + l: num_bits, + z, + precomputed_points, + } + } + + fn calculate_padded_size(l: usize, w: usize) -> usize { + let a = (l + w - 1) / w; // This is ⌈l/ω⌉ + let padded_size = a * w; + // TODO: if statement not needed, if we do div_ceil + let padding_zeros = if l % w == 0 { 0 } else { padded_size - l }; + padding_zeros + l + } + + fn scalar_mul_naive(&self, scalar: &Scalar) -> G1Projective { + // Convert scalar to bits and pad it to the necessary length + let mut wnaf_digits = scalar_to_bits(*scalar).to_vec(); + wnaf_digits.extend(vec![0u8; self.l - wnaf_digits.len()]); + + let point = G1Projective::generator(); + + let mut result = G1Projective::identity(); + + for t in 0..self.z { + // t is used to scan a square + let square_offset = t * self.w * self.w; + for i in 0..self.w { + // i is used to scan a particular row + // + // + + // Collect all of the necessary bits that differ by a factor of omega + let digits = select_elements(&wnaf_digits, self.w as usize, t as usize, i as usize); + // I need to figure out the bit position for this + for (index, digit) in digits.into_iter().enumerate() { + let exponent = square_offset + i + index * self.w; + result += point + * Scalar::from(*digit as u64) + * Scalar::from(2u64).pow(&[exponent as u64]); + } + } + } + result + } + + fn scalar_mul_naive_wnaf(&self, scalar: &Scalar) -> G1Projective { + fn scalar_to_wnaf(scalar: Scalar, num_bits: usize, omega: usize) -> Vec { + let mut wnaf_digits = vec![]; + let scalar_bytes = scalar.to_bytes_le().to_vec(); + wnaf_form(&mut wnaf_digits, scalar_bytes, omega); + + // TODO: the 
wnaf algorithm will pad unecessary zeroes + // which then makes the padding algorithm below pad it even more in some cases. + // We can either fix wnaf_form or remove the extra omega zeroes and then pad + + // Pad wnaf_digits to the next multiple of w^2 + let len = wnaf_digits.len(); + let w_squared = omega * omega; + let num_sectors = (len + w_squared - 1) / w_squared; + let padded_len = num_sectors * w_squared; + wnaf_digits.extend(vec![0i64; padded_len - len]); + + wnaf_digits + } + // Convert scalar to bits and pad it to the necessary length + // let mut wnaf_digits = scalar_to_bits(*scalar).to_vec(); + // wnaf_digits.extend(vec![0u8; self.l - wnaf_digits.len()]); + + let mut wnaf_digits = scalar_to_wnaf(*scalar, self.l, self.w); + + let point = G1Projective::generator(); + + let mut result = G1Projective::identity(); + + for t in 0..self.z { + // t is used to scan a square + let square_offset = t * self.w * self.w; + for i in 0..self.w { + // i is used to scan a particular row + // + // + + // Collect all of the necessary bits that differ by a factor of omega + let digits = select_elements(&wnaf_digits, self.w as usize, t as usize, i as usize); + + for (index, digit) in digits.into_iter().enumerate() { + if *digit == 0 { + continue; + } + + let is_negative = digit.is_negative(); + + let exponent = square_offset + i + index * self.w; + let two_pow = Scalar::from(2u64).pow(&[exponent as u64]); + let digit = Scalar::from(digit.unsigned_abs()); + + if is_negative { + result -= point * digit * two_pow; + } else { + result += point * digit * two_pow; + } + } + } + } + result + } + + fn scalar_mul_naive_wnaf_iterated(&self, scalar: &Scalar) -> G1Projective { + fn scalar_to_wnaf(scalar: Scalar, num_bits: usize, omega: usize) -> Vec { + let mut wnaf_digits = vec![]; + let scalar_bytes = scalar.to_bytes_le().to_vec(); + wnaf_form(&mut wnaf_digits, scalar_bytes, omega); + + // TODO: the wnaf algorithm will pad unnecessary zeroes + // which then makes the padding algorithm below pad it even more in some cases. 
+ // We can either fix wnaf_form or remove the extra omega zeroes and then pad + + // Pad wnaf_digits to the next multiple of w^2 + let len = wnaf_digits.len(); + let w_squared = omega * omega; + let num_sectors = (len + w_squared - 1) / w_squared; + let padded_len = num_sectors * w_squared; + wnaf_digits.extend(vec![0i64; padded_len - len]); + + wnaf_digits + } + + let mut result = G1Projective::identity(); + let point = G1Projective::generator(); + + let mut wnaf_digits = scalar_to_wnaf(*scalar, self.l, self.w); + for t in 0..self.z { + // t is used to scan a square + let square_offset = t * self.w * self.w; + for i in 0..self.w { + // i is used to scan a particular row + // + // + + // Collect all of the necessary bits that differ by a factor of omega + let digits = select_elements(&wnaf_digits, self.w as usize, t as usize, i as usize); + + let mut total_value = 0; + for (index, digit) in digits.iter().enumerate() { + total_value += (**digit as i64) * (1 << index as i64 * self.w as i64); + } + + if total_value == 0 { + continue; + } + + let is_negative = total_value.is_negative(); + let two_pow_offset = Scalar::from(2u64).pow(&[square_offset as u64]); + let two_pow_i = Scalar::from(2u64).pow(&[i as u64]); + + if is_negative { + result -= point + * Scalar::from(total_value.unsigned_abs()) + * two_pow_offset + * two_pow_i; + } else { + result += point + * Scalar::from(total_value.unsigned_abs()) + * two_pow_offset + * two_pow_i; + } + } + } + result + } + + fn scalar_mul_precomps_wnaf(&self, scalar: &Scalar) -> G1Projective { + fn scalar_to_wnaf(scalar: Scalar, num_bits: usize, omega: usize) -> Vec { + let mut wnaf_digits = vec![]; + let scalar_bytes = scalar.to_bytes_le().to_vec(); + wnaf_form(&mut wnaf_digits, scalar_bytes, omega); + + // TODO: the wnaf algorithm will pad unnecessary zeroes + // which then makes the padding algorithm below pad it even more in some cases. + // We can either fix wnaf_form or remove the extra omega zeroes and then pad + + // Pad wnaf_digits to the next multiple of w^2 + let len = wnaf_digits.len(); + let w_squared = omega * omega; + let num_sectors = (len + w_squared - 1) / w_squared; + let padded_len = num_sectors * w_squared; + wnaf_digits.extend(vec![0i64; padded_len - len]); + + wnaf_digits + } + + let mut result = G1Projective::identity(); + let point = G1Projective::generator(); + + let mut square_precomputations = Vec::new(); + let mut precomputations = Vec::new(); + // numbers are of the form a_0 + 2^w a_1 + 2^2w a_2 +... 
a_w 2^w*w + for i in 1..(1 << self.w * self.w) { + precomputations.push(point * Scalar::from(i as u64)); + } + square_precomputations.push(precomputations); + + // Precompute the values across rows, across the square + for k in 0..self.z { + // Take the last + let last_square = square_precomputations.last().unwrap().clone(); + // double all elements in the last square w*w times + let shifted_square: Vec<_> = last_square + .into_par_iter() + .map(|mut point| { + for _ in 0..(self.w * self.w) { + point = point.double(); + } + point + }) + .collect(); + + square_precomputations.push(shifted_square); + } + + let mut wnaf_digits = scalar_to_wnaf(*scalar, self.l, self.w); + for i in (0..self.w).rev() { + result = result.double(); + + for t in (0..self.z) { + // t is used to scan a square + // i is used to scan a particular row + // + // + + // Collect all of the necessary bits that differ by a factor of omega + let digits = select_elements(&wnaf_digits, self.w as usize, t as usize, i as usize); + + let mut total_value = 0; + for (index, digit) in digits.iter().enumerate() { + total_value += (**digit as i64) * (1 << index as i64 * self.w as i64); + } + + if total_value == 0 { + continue; + } + + let is_negative = total_value.is_negative(); + // let two_pow_offset = Scalar::from(2u64).pow(&[square_offset as u64]); + // let two_pow_i = Scalar::from(2u64).pow(&[i as u64]); + + let mut chosen_point = + square_precomputations[t][(total_value.unsigned_abs() as usize - 1)]; + + // for _ in 0..i { + // chosen_point = chosen_point.double() + // } + // for _ in 0..square_offset { + // chosen_point = chosen_point.double() + // } + + if is_negative { + result -= chosen_point; + } else { + result += chosen_point; + } + } + } + result + } + fn precompute_point(point: G1Affine, omega: usize, z: usize) -> PrecomputationsForPoint { + let point = G1Projective::from(point); + + let mut square_precomputations = Vec::new(); + let mut precomputations = Vec::new(); + // numbers are of the form a_0 + 2^w a_1 + 2^2w a_2 +... a_w 2^w*w + for i in 1..(1 << omega * omega) { + precomputations.push(point * Scalar::from(i as u64)); + } + square_precomputations.push(precomputations); + + // Precompute the values across rows, across the square + for k in 0..z { + // Take the last + let last_square = square_precomputations.last().unwrap().clone(); + // double all elements in the last square w*w times + let shifted_square: Vec<_> = last_square + .into_par_iter() + .map(|mut point| { + for _ in 0..(omega * omega) { + point = point.double(); + } + point + }) + .collect(); + + square_precomputations.push(shifted_square); + } + + square_precomputations + .into_par_iter() + .map(|sp| g1_batch_normalize(&sp)) + .collect() + } + fn msm(&self, scalars: &[Scalar]) -> G1Projective { + fn scalar_to_wnaf(scalar: Scalar, num_bits: usize, omega: usize) -> Vec { + let mut wnaf_digits = vec![]; + let scalar_bytes = scalar.to_bytes_le().to_vec(); + wnaf_form(&mut wnaf_digits, scalar_bytes, omega); + + // TODO: the wnaf algorithm will pad unnecessary zeroes + // which then makes the padding algorithm below pad it even more in some cases. 
+ // We can either fix wnaf_form or remove the extra omega zeroes and then pad + + // Pad wnaf_digits to the next multiple of w^2 + let len = wnaf_digits.len(); + let w_squared = omega * omega; + let num_sectors = (len + w_squared - 1) / w_squared; + let padded_len = num_sectors * w_squared; + wnaf_digits.extend(vec![0i64; padded_len - len]); + + wnaf_digits + } + + let mut result = G1Projective::identity(); + let now = std::time::Instant::now(); + let scalars_wnaf_digits: Vec<_> = scalars + .into_iter() + .map(|scalar| scalar_to_wnaf(*scalar, self.l, self.w)) + .collect(); + let mut rows = vec![vec![]; self.w]; + + for (scalar_index, wnaf_digits) in scalars_wnaf_digits.into_iter().enumerate() { + for i in (0..self.w).rev() { + for t in (0..self.z) { + // t is used to scan a square + // i is used to scan a particular row + // + // + + // Collect all of the necessary bits that differ by a factor of omega + let digits = + select_elements_iter(&wnaf_digits, self.w as usize, t as usize, i as usize); + + let mut total_value = 0; + for (index, digit) in digits.enumerate() { + total_value += (digit as i64) * (1 << index as i64 * self.w as i64); + } + + if total_value == 0 { + continue; + } + + let is_negative = total_value.is_negative(); + // let two_pow_offset = Scalar::from(2u64).pow(&[square_offset as u64]); + // let two_pow_i = Scalar::from(2u64).pow(&[i as u64]); + + let mut chosen_point = self.precomputed_points[scalar_index][t] + [(total_value.unsigned_abs() as usize - 1)]; + + let chosen_point = if is_negative { + -chosen_point + } else { + chosen_point + }; + rows[i].push(chosen_point); + } + } + } + + // Sum all rows together + let summed_rows = multi_batch_addition_diff_stride(rows); + + // Combine rows together (they differ by a factor of 2) + let mut result = G1Projective::identity(); + for row in summed_rows.into_iter().rev() { + result = result.double(); + result += row; + } + + dbg!(now.elapsed().as_micros()); + + result + } +} + +fn select_elements(vector: &[T], w: usize, sector: usize, offset: usize) -> Vec<&T> { + // Calculate the total number of sectors + let total_sectors = vector.len() / (w * w); + + // Validate that the vector length is a multiple of w squared + if vector.len() % (w * w) != 0 { + panic!( + "The size of the vector must be a multiple of w squared. got = {}, expected = {}", + vector.len(), + w * w + ); + } + // Validate that the sector index is within the valid range + if sector >= total_sectors { + panic!("Sector index out of range."); + } + // Validate that the offset is within the valid range + if offset >= w { + panic!("Offset must be in the range [0, w - 1]."); + } + // Calculate the starting index of the sector + let sector_start = sector * w * w; + // Collect the selected elements + let selected_elements: Vec<&T> = (0..w) + .map(|k| &vector[sector_start + offset + k * w]) + .collect(); + selected_elements +} +fn select_elements_iter<'a, T: Copy + Clone>( + vector: &'a [T], + w: usize, + sector: usize, + offset: usize, +) -> impl Iterator + 'a { + // Calculate the total number of sectors + let total_sectors = vector.len() / (w * w); + + // Validate that the vector length is a multiple of w squared + if vector.len() % (w * w) != 0 { + panic!( + "The size of the vector must be a multiple of w squared. 
got = {}, expected = {}", + vector.len(), + w * w + ); + } + // Validate that the sector index is within the valid range + if sector >= total_sectors { + panic!("Sector index out of range."); + } + // Validate that the offset is within the valid range + if offset >= w { + panic!("Offset must be in the range [0, w - 1]."); + } + // Calculate the starting index of the sector + let sector_start = sector * w * w; + // Collect the selected elements + (0..w).map(move |k| vector[sector_start + offset + k * w]) +} + +#[test] +fn test_debug_vector_selector() { + let w = 4; + let num_sectors = 4; + // Create a vector with 3 sectors, each of size 16 (4*4), total size 48 + let vector: Vec = (0..(num_sectors * w * w)) + .map(|i| format!("b_{}", i)) + .collect(); + + let sector = 1; // Choose the sector index (0-based) + let offset = 2; // Starting offset within the sector + + let selected = select_elements(&vector, w, sector, offset); + + let t = sector; + let i = offset; + let square_offset = t * w * w; + + for index_ in 0..w { + let exp = square_offset + i + index_ * w; + dbg!(exp); + } + + println!( + "Selected elements from sector {} with offset {}:", + sector, offset + ); + println!("{:?}", selected); +} + +#[test] +fn test_seo_kim_naive_scalar_mul() { + let scalar = -Scalar::from(2u64); + let result = G1Projective::generator() * scalar; + + let w = 4; + let sk = SeoKim::new(w, &[G1Affine::generator()]); + + let got = sk.scalar_mul_naive(&scalar); + assert_eq!(got, result); + + let got = sk.scalar_mul_naive_wnaf(&scalar); + assert_eq!(got, result); + + let got = sk.scalar_mul_naive_wnaf_iterated(&scalar); + assert_eq!(got, result); + + // let got = sk.scalar_mul_precomps_wnaf(&scalar); + // assert_eq!(got, result); + + let got = sk.msm(&[scalar]); + assert_eq!(got, result); +} + +#[test] +fn test_seokim_msm() { + let num_points = 64; + let points = random_points(num_points); + + let w = 4; + let sk = SeoKim::new(w, &points); + + let scalars: Vec<_> = (0..num_points) + .into_iter() + .map(|_| Scalar::random(&mut rand::thread_rng())) + .collect(); + + let mut expected = G1Projective::identity(); + for (scalar, point) in scalars.iter().zip(points.iter()) { + expected += G1Projective::from(*point) * scalar + } + + let got = sk.msm(&scalars); + assert_eq!(got, expected); +} + +fn random_points(num_points: usize) -> Vec { + (0..num_points) + .into_iter() + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect() +} diff --git a/cryptography/bls12_381/src/simple_msm.rs b/cryptography/bls12_381/src/simple_msm.rs new file mode 100644 index 00000000..157f69bb --- /dev/null +++ b/cryptography/bls12_381/src/simple_msm.rs @@ -0,0 +1,186 @@ +use blstrs::{G1Affine, G1Projective, Scalar}; +use ff::Field; +use group::Group; + +use crate::{ + batch_add::{batch_addition, batch_addition_diff_stride, multi_batch_addition_diff_stride}, + wnaf::wnaf_form, +}; + +// This just generalizes the double and add algorithm +pub struct SimpleMsm; + +pub fn msm_sjf(points: &[G1Affine], scalars: &[Scalar]) -> G1Projective { + let mut scalars_bytes: Vec<_> = scalars + .into_iter() + .map(|scalar| scalar.to_bytes_le()) + .collect(); + let scalars_jsf = calculate_dsjsf(&scalars_bytes); + let mut buckets = vec![vec![]; 256]; + for (scalar_index, scalar_bits) in scalars_jsf.into_iter().enumerate() { + for (index, bit) in scalar_bits.into_iter().enumerate() { + if bit < 0 { + buckets[index].push(-points[scalar_index]); + } else if bit > 0 { + buckets[index].push(points[scalar_index]); + } + } + } + + let mut result = 
G1Projective::identity(); + let summed_windows = multi_batch_addition_diff_stride(buckets); + for (window) in summed_windows.into_iter().rev() { + result = result.double(); + result += window; + } + + result +} +pub fn msm(points: &[G1Affine], scalars: &[Scalar]) -> G1Projective { + let scalars_bits: Vec<_> = scalars.into_iter().map(|s| scalar_to_bits(*s)).collect(); + + let mut buckets = vec![vec![]; 256]; + + for (scalar_index, scalar_bits) in scalars_bits.into_iter().enumerate() { + // iterate over scalar + for (index, bit) in scalar_bits.into_iter().enumerate() { + if bit != 0 { + buckets[index].push(points[scalar_index]); + } + } + } + + let mut result = G1Projective::identity(); + let summed_windows = multi_batch_addition_diff_stride(buckets); + for (window) in summed_windows.into_iter().rev() { + result = result.double(); + result += window; + } + + result +} + +pub fn scalar_to_bits(s: Scalar) -> [u8; 256] { + let scalar_bytes = s.to_bytes_le(); + bytes_to_bits(scalar_bytes) +} +fn bytes_to_bits(bytes: [u8; 32]) -> [u8; 256] { + let mut bit_vector = Vec::with_capacity(256); + for byte in bytes { + for i in 0..8 { + bit_vector.push(((byte >> i) & 0x01) as u8) + } + } + bit_vector.try_into().unwrap() +} + +pub fn calculate_dsjsf(x: &[[u8; 32]]) -> Vec> { + let d = x.len(); + let max_len = x.iter().map(|xi| xi.len()).max().unwrap_or(0); + let mut result = vec![vec![0i8; 0]; d]; + let mut x_copy: Vec> = x.iter().map(|&xi| xi.to_vec()).collect(); + + let mut j = 0; + let mut a = vec![Vec::new(); 2]; + + loop { + if x_copy.iter().all(|xi| xi.iter().all(|&b| b == 0)) { + break; + } + + let mut xj = vec![0i8; d]; + a.push(Vec::new()); + + for k in 0..d { + if let Some(&last_byte) = x_copy[k].last() { + xj[k] = (last_byte & 1) as i8; + if xj[k] == 1 { + a[j].push(k); + } + } + } + + for k in 0..d { + if x_copy[k].len() > 1 || (x_copy[k].len() == 1 && x_copy[k][0] > 1) { + let next_bit = ((x_copy[k].last().unwrap_or(&0) >> 1) & 1) as i8; + if next_bit == 1 { + a[j + 1].push(k); + } + } + } + + if a[j + 1].iter().all(|&k| a[j].contains(&k)) { + for &k in &a[j + 1] { + xj[k] = -xj[k]; + } + a[j + 1].clear(); + } else { + for &k in &a[j] { + if !a[j + 1].contains(&k) { + xj[k] = -xj[k]; + } + } + a[j + 1] = a[j] + .iter() + .cloned() + .chain(a[j + 1].iter().cloned()) + .collect(); + } + + for k in 0..d { + result[k].insert(0, xj[k]); + if !x_copy[k].is_empty() { + let mut borrow = (xj[k] < 0) as u8; + for byte in x_copy[k].iter_mut().rev() { + let (new_byte, new_borrow) = byte.overflowing_sub(borrow); + *byte = new_byte; + if new_borrow { + borrow = 1; + } else { + break; + } + } + divide_by_two(&mut x_copy[k]); + } + } + + j += 1; + } + + result +} + +/// Helper function to divide a big-endian byte array by 2 +fn divide_by_two(num: &mut Vec) { + let mut carry = 0; + for byte in num.iter_mut().rev() { + let new_carry = *byte & 1; + *byte = (*byte >> 1) | (carry << 7); + carry = new_carry; + } + while num.len() > 1 && num[0] == 0 { + num.remove(0); + } +} + +fn random_points(num_points: usize) -> Vec { + use group::Group; + (0..num_points) + .into_iter() + .map(|_| G1Projective::random(&mut rand::thread_rng()).into()) + .collect() +} + +#[test] +fn test_simple_msm() { + use ff::Field; + let num_points = 64; + let points = random_points(num_points); + let scalars: Vec<_> = (0..num_points) + .into_iter() + .map(|i| Scalar::random(&mut rand::thread_rng())) + .collect(); + let now = std::time::Instant::now(); + msm_sjf(&points, &scalars); + dbg!(now.elapsed().as_micros()); +} diff --git 
a/cryptography/bls12_381/src/wnaf.rs b/cryptography/bls12_381/src/wnaf.rs new file mode 100644 index 00000000..bd48fd37 --- /dev/null +++ b/cryptography/bls12_381/src/wnaf.rs @@ -0,0 +1,124 @@ +/* +TAKEN from group crate as they don't expose wnaf. +Add proper reference here + +*/ + +/// This struct represents a view of a sequence of bytes as a sequence of +/// `u64` limbs in little-endian byte order. It maintains a current index, and +/// allows access to the limb at that index and the one following it. Bytes +/// beyond the end of the original buffer are treated as zero. +struct LimbBuffer<'a> { + buf: &'a [u8], + cur_idx: usize, + cur_limb: u64, + next_limb: u64, +} + +impl<'a> LimbBuffer<'a> { + fn new(buf: &'a [u8]) -> Self { + let mut ret = Self { + buf, + cur_idx: 0, + cur_limb: 0, + next_limb: 0, + }; + + // Initialise the limb buffers. + ret.increment_limb(); + ret.increment_limb(); + ret.cur_idx = 0usize; + + ret + } + + fn increment_limb(&mut self) { + self.cur_idx += 1; + self.cur_limb = self.next_limb; + match self.buf.len() { + // There are no more bytes in the buffer; zero-extend. + 0 => self.next_limb = 0, + + // There are fewer bytes in the buffer than a u64 limb; zero-extend. + x @ 1..=7 => { + let mut next_limb = [0; 8]; + next_limb[..x].copy_from_slice(self.buf); + self.next_limb = u64::from_le_bytes(next_limb); + self.buf = &[]; + } + + // There are at least eight bytes in the buffer; read the next u64 limb. + _ => { + let (next_limb, rest) = self.buf.split_at(8); + self.next_limb = u64::from_le_bytes(next_limb.try_into().unwrap()); + self.buf = rest; + } + } + } + + fn get(&mut self, idx: usize) -> (u64, u64) { + assert!([self.cur_idx, self.cur_idx + 1].contains(&idx)); + if idx > self.cur_idx { + self.increment_limb(); + } + (self.cur_limb, self.next_limb) + } +} + +/// Replaces the contents of `wnaf` with the w-NAF representation of a little-endian +/// scalar. +pub(crate) fn wnaf_form>(wnaf: &mut Vec, c: S, window: usize) { + // Required by the NAF definition + debug_assert!(window >= 2); + // Required so that the NAF digits fit in i64 + debug_assert!(window <= 64); + + let bit_len = c.as_ref().len() * 8; + + wnaf.truncate(0); + wnaf.reserve(bit_len); + + // Initialise the current and next limb buffers. + let mut limbs = LimbBuffer::new(c.as_ref()); + + let width = 1u64 << window; + let window_mask = width - 1; + + let mut pos = 0; + let mut carry = 0; + while pos < bit_len { + // Construct a buffer of bits of the scalar, starting at bit `pos` + let u64_idx = pos / 64; + let bit_idx = pos % 64; + let (cur_u64, next_u64) = limbs.get(u64_idx); + let bit_buf = if bit_idx + window < 64 { + // This window's bits are contained in a single u64 + cur_u64 >> bit_idx + } else { + // Combine the current u64's bits with the bits from the next u64 + (cur_u64 >> bit_idx) | (next_u64 << (64 - bit_idx)) + }; + + // Add the carry into the current window + let window_val = carry + (bit_buf & window_mask); + + if window_val & 1 == 0 { + // If the window value is even, preserve the carry and emit 0. + // Why is the carry preserved? 
+ // If carry == 0 and window_val & 1 == 0, then the next carry should be 0 + // If carry == 1 and window_val & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1 + wnaf.push(0); + pos += 1; + } else { + wnaf.push(if window_val < width / 2 { + carry = 0; + window_val as i64 + } else { + carry = 1; + (window_val as i64).wrapping_sub(width as i64) + }); + wnaf.extend(std::iter::repeat(0).take(window - 1)); + pos += window; + } + } +} diff --git a/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs b/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs index 610b0566..07028ea2 100644 --- a/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs +++ b/cryptography/kzg_multi_open/src/fk20/batch_toeplitz.rs @@ -1,6 +1,7 @@ use crate::fk20::toeplitz::{CirculantMatrix, ToeplitzMatrix}; use bls12_381::{ fixed_base_msm::{FixedBaseMSM, UsePrecomp}, + fixed_base_msm_blst::FixedBaseMultiMSMPrecompBLST, g1_batch_normalize, G1Point, G1Projective, }; use maybe_rayon::prelude::*; @@ -18,6 +19,7 @@ pub struct BatchToeplitzMatrixVecMul { /// we can do in a batch. batch_size: usize, precomputed_fft_vectors: Vec, + // precomputed_fft_vectors: FixedBaseMultiMSMPrecompBLST, // This is the length of the vector that we are multiplying the matrices with. // and subsequently will be the length of the final result of the matrix-vector multiplication. size_of_vector: usize, @@ -65,6 +67,8 @@ impl BatchToeplitzMatrixVecMul { .map(|v| FixedBaseMSM::new(v, use_precomp)) .collect(); + // let precomputed_table = FixedBaseMultiMSMPrecompBLST::new(transposed_msm_vectors, 8); + BatchToeplitzMatrixVecMul { size_of_vector, circulant_domain, @@ -102,12 +106,17 @@ impl BatchToeplitzMatrixVecMul { .collect(); let msm_scalars = transpose(col_ffts); + // let now = std::time::Instant::now(); let result: Vec<_> = (&self.precomputed_fft_vectors) .maybe_par_iter() .zip(msm_scalars) .map(|(points, scalars)| points.msm(scalars)) .collect(); + // dbg!(now.elapsed().as_micros()); + // let now = std::time::Instant::now(); + // let result = self.precomputed_fft_vectors.multi_msm(msm_scalars); + // dbg!(now.elapsed().as_micros()); // Once the aggregate circulant matrix-vector multiplication is done, we need to take the first half // of the result, as the second half are extra terms that were added due to the fact that the Toeplitz matrices // were embedded into circulant matrices. diff --git a/cryptography/kzg_multi_open/src/fk20/prover.rs b/cryptography/kzg_multi_open/src/fk20/prover.rs index 003214b2..84dd9fec 100644 --- a/cryptography/kzg_multi_open/src/fk20/prover.rs +++ b/cryptography/kzg_multi_open/src/fk20/prover.rs @@ -195,7 +195,6 @@ impl FK20Prover { let h_poly_commitments = compute_h_poly_commitments(&self.batch_toeplitz, polynomial.clone(), self.coset_size); let mut proofs = self.proof_domain.fft_g1(h_poly_commitments); - // Reverse bit order the set of proofs, so that the proofs line up with the // coset evaluations. 
reverse_bit_order(&mut proofs); diff --git a/cryptography/polynomial/Cargo.toml b/cryptography/polynomial/Cargo.toml index ae279477..adf9d5d7 100644 --- a/cryptography/polynomial/Cargo.toml +++ b/cryptography/polynomial/Cargo.toml @@ -12,7 +12,7 @@ repository = { workspace = true } [dependencies] bls12_381 = { workspace = true } - +hex = "*" [dev-dependencies] criterion = "0.5.1" rand = "0.8.4" diff --git a/cryptography/polynomial/src/domain.rs b/cryptography/polynomial/src/domain.rs index 1f231047..2324bfc9 100644 --- a/cryptography/polynomial/src/domain.rs +++ b/cryptography/polynomial/src/domain.rs @@ -108,7 +108,7 @@ impl Domain { // domain. polynomial.resize(self.size(), Scalar::ZERO); - fft_scalar(self.generator, &polynomial) + fft_scalar_new(self.generator, &polynomial) } /// Evaluates a polynomial at the points in the domain multiplied by a coset @@ -123,7 +123,7 @@ impl Domain { *point *= coset_scale; coset_scale *= self.coset_generator; } - fft_scalar(self.generator, &points) + fft_scalar_new(self.generator, &points) } /// Computes a DFT for the group elements(elliptic curve points) using the roots in the domain. @@ -136,6 +136,12 @@ impl Domain { points.resize(self.size(), G1Projective::identity()); fft_g1(self.generator, &points) } + pub fn fft_g1_new(&self, mut points: Vec) -> Vec { + // Pad the vector of points with zeroes, so that it is the same size as the + // domain. + points.resize(self.size(), G1Projective::identity()); + fft_g1_new(self.generator, &points) + } /// Computes an IDFT for the group elements(elliptic curve points) using the roots in the domain. pub fn ifft_g1(&self, points: Vec) -> Vec { @@ -158,7 +164,7 @@ impl Domain { // domain. points.resize(self.size(), G1Projective::identity()); - let ifft_g1 = fft_g1(self.generator_inv, &points); + let ifft_g1 = fft_g1_new(self.generator_inv, &points); // Truncate the result if a value of `n` was supplied. let mut ifft_g1 = match n { @@ -183,7 +189,7 @@ impl Domain { // domain. 
points.resize(self.size(), Scalar::ZERO); - let mut ifft_scalar = fft_scalar(self.generator_inv, &points); + let mut ifft_scalar = fft_scalar_new(self.generator_inv, &points); for element in ifft_scalar.iter_mut() { *element *= self.domain_size_inv @@ -234,6 +240,63 @@ fn fft_scalar(nth_root_of_unity: Scalar, points: &[Scalar]) -> Vec { evaluations } +pub fn fft_scalar_inplace(a: &mut [Scalar], nth_root_of_unity: Scalar) { + let n = a.len(); + let log_n = log2_pow2(n); + assert_eq!(n, 1 << log_n); + + // Bit-reversal permutation + for k in 0..n { + let rk = bitreverse(k as u32, log_n) as usize; + if k < rk { + a.swap(rk, k); + } + } + + // Main FFT computation + let mut m = 1; + for s in 0..log_n { + let w_m = nth_root_of_unity.pow(&[(n / (2 * m)) as u64]); + for k in (0..n).step_by(2 * m) { + let mut w = Scalar::ONE; + for j in 0..m { + let t = if w == Scalar::ONE { + a[k + j + m] + } else if w == -Scalar::ONE { + -a[k + j + m] + } else { + a[k + j + m] * w + }; + let u = a[k + j]; + a[k + j] = u + t; + a[k + j + m] = u - t; + w *= w_m; + } + } + m *= 2; + } +} + +fn bitreverse(mut n: u32, l: u32) -> u32 { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r +} + +fn log2_pow2(n: usize) -> u32 { + n.trailing_zeros() +} + +// Helper function to create a new vector and perform FFT +pub fn fft_scalar_new(nth_root_of_unity: Scalar, points: &[Scalar]) -> Vec { + let mut a = points.to_vec(); + fft_scalar_inplace(&mut a, nth_root_of_unity); + a +} + /// Computes a DFT of the group elements(points) using powers of the roots of unity. /// /// Note: This is essentially multiple multi-scalar multiplications. @@ -264,6 +327,66 @@ fn fft_g1(nth_root_of_unity: Scalar, points: &[G1Projective]) -> Vec Vec { + let mut a = points.to_vec(); + // let now = std::time::Instant::now(); + fft_g1_inplace(&mut a, nth_root_of_unity); + // dbg!(now.elapsed().as_micros()); + a +} + +fn precompute_twiddle_factors(omega: &F, n: usize) -> Vec { + let log_n = log2_pow2(n); + (0..log_n) + .map(|s| omega.pow(&[(n / (1 << (s + 1))) as u64])) + .collect() +} /// Splits the list into two lists, one containing the even indexed elements /// and the other containing the odd indexed elements. 
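The in-place FFT added above relies on the standard bit-reversal permutation before the butterfly passes. A minimal, dependency-free sketch of that permutation follows; `bitreverse` is copied from the hunk above so the check is self-contained, and the test name is illustrative rather than part of this patch.

/// Bit-reverses the low `l` bits of `n` (same helper as in the hunk above).
fn bitreverse(mut n: u32, l: u32) -> u32 {
    let mut r = 0;
    for _ in 0..l {
        r = (r << 1) | (n & 1);
        n >>= 1;
    }
    r
}

#[test]
fn bit_reversal_permutation_smoke_test() {
    let log_n = 3u32;
    let n = 1usize << log_n;

    // Apply the same swap rule as `fft_scalar_inplace`: swap k and rk only
    // when k < rk, so every pair is swapped exactly once.
    let mut a: Vec<usize> = (0..n).collect();
    for k in 0..n {
        let rk = bitreverse(k as u32, log_n) as usize;
        if k < rk {
            a.swap(k, rk);
        }
    }
    assert_eq!(a, vec![0, 4, 2, 6, 1, 5, 3, 7]);

    // Reversing twice is the identity, which is why the index map above
    // really is a permutation (and its own inverse).
    for k in 0..n as u32 {
        assert_eq!(bitreverse(bitreverse(k, log_n), log_n), k);
    }
}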
@@ -284,6 +407,8 @@ fn take_even_odd(list: &[T]) -> (Vec, Vec) { #[cfg(test)] mod tests { + use bls12_381::G1Point; + use crate::monomial::poly_eval; use super::*; @@ -346,6 +471,323 @@ mod tests { assert_eq!(got_poly, polynomial); } + #[test] + fn test_fft_g1_bench() { + let domain = Domain::new(64); + + let points: Vec<_> = (0..64) + .map(|_| G1Projective::random(&mut rand::thread_rng())) + .collect(); + let now = std::time::Instant::now(); + let res = domain.fft_g1(points.clone()); + dbg!(now.elapsed().as_millis()); + let now = std::time::Instant::now(); + let res2 = domain.fft_g1_new(points); + dbg!(now.elapsed().as_millis()); + assert_eq!(res, res2) + } + + #[test] + fn test_fft_twiddle_factor_hamming_weight() { + let domain = Domain::new(64); + + let twiddle_factors = precompute_twiddle_factors(&domain.generator, 64); + let mut scaled_twiddles = Vec::new(); + let log_n = 6; + // Main FFT computation + let mut m = 1; + let n = 64; + for s in 0..log_n { + let w_m = twiddle_factors[s as usize]; + for k in (0..n).step_by(2 * m) { + let mut w = Scalar::ONE; + for j in 0..m { + w *= w_m; + scaled_twiddles.push(w); + } + } + m *= 2; + } + + dbg!(&domain.roots); + + let root = domain.roots.last().unwrap(); + let now = std::time::Instant::now(); + let res = G1Projective::generator() * root; + dbg!(now.elapsed().as_micros()); + let resAff = G1Point::from(res); + dbg!(hex::encode(resAff.to_compressed())); + let now = std::time::Instant::now(); + g1_projective_operations(G1Projective::generator()); + dbg!(now.elapsed().as_micros()); + } + fn g1_projective_operations(point: G1Projective) -> G1Projective { + let _1 = point; + let _10 = _1.double(); + let _11 = _1 + _10; + let _101 = _10 + _11; + let _110 = _1 + _101; + let _111 = _1 + _110; + let _1001 = _10 + _111; + let _1011 = _10 + _1001; + let _1101 = _10 + _1011; + let _1111 = _10 + _1101; + let _10001 = _10 + _1111; + let _10111 = _110 + _10001; + let _11001 = _10 + _10111; + let _11011 = _10 + _11001; + let _11101 = _10 + _11011; + let _110010 = _10111 + _11011; + + let i40 = ((_110010 + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1101) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1001) + .double() + .double() + .double() + .double() + .double() + .double(); + let i56 = ((_11101 + i40).double().double().double().double().double() + _11) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1101; + let i78 = ((i56 + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _10111) + .double() + .double() + .double() + .double() + .double() + .double() + + _11101) + .double() + .double() + .double() + .double() + .double() + .double(); + let i91 = ((_10111 + i78).double().double().double() + _101) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _11101; + let i110 = ((i91 + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1001) + .double() + .double() + .double() + .double() + .double() + .double() + + _11001) + .double() + .double() + .double(); + let i125 = ((_111 + i110) + .double() + .double() + .double() + .double() + .double() + .double() + + _1101) + .double() + .double() + .double() + .double() + .double() + .double() + + _11101; + let i140 = ((_10 + i125).double().double().double().double() + _111) + .double() + .double() + .double() + 
.double() + .double() + .double() + .double() + .double() + + _11101; + let i160 = ((i140.double().double().double().double().double().double() + _11011) + .double() + .double() + .double() + + _11) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double(); + let i175 = ((_11101 + i160).double().double().double().double().double() + _10111) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _111; + let i195 = ((i175 + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _10111) + .double() + .double() + .double() + .double() + .double() + + _1111) + .double() + .double() + .double() + .double() + .double() + .double(); + let i207 = ((_1111 + i195).double().double().double().double() + _111) + .double() + .double() + .double() + .double() + .double() + + _111; + let i227 = ((i207.double().double().double().double().double().double() + _101) + .double() + .double() + .double() + .double() + .double() + .double() + + _111) + .double() + .double() + .double() + .double() + .double() + .double(); + let i241 = ((_1111 + i227) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _10001) + .double() + .double() + .double() + + _101; + let i261 = ((i241.double().double().double().double().double() + _101) + .double() + .double() + .double() + .double() + .double() + + _11) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + .double(); + let i277 = ((_101 + i261) + .double() + .double() + .double() + .double() + .double() + .double() + + _11) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1011; + let i295 = ((i277.double().double().double().double().double().double() + _10001) + .double() + .double() + .double() + .double() + .double() + .double() + + _11101) + .double() + .double() + .double() + .double(); + + ((_1101 + i295) + .double() + .double() + .double() + .double() + .double() + .double() + .double() + + _1011) + .double() + .double() + } + #[test] fn fft_g1_smoke_test() { fn naive_msm(points: &[G1Projective], scalars: &[Scalar]) -> G1Projective { diff --git a/eip7594/benches/benchmark.rs b/eip7594/benches/benchmark.rs index d038879a..2da2c5ba 100644 --- a/eip7594/benches/benchmark.rs +++ b/eip7594/benches/benchmark.rs @@ -30,23 +30,24 @@ fn dummy_commitment_cells_and_proofs() -> ( (commitment, ctx.compute_cells_and_kzg_proofs(&blob).unwrap()) } -const THREAD_COUNTS: [ThreadCount; 5] = [ - ThreadCount::Single, - ThreadCount::Multi(4), - ThreadCount::Multi(8), - ThreadCount::Multi(16), - ThreadCount::Multi(32), -]; +// const THREAD_COUNTS: [ThreadCount; 5] = [ +// ThreadCount::Single, +// ThreadCount::Multi(4), +// ThreadCount::Multi(8), +// ThreadCount::Multi(16), +// ThreadCount::Multi(32), +// ]; pub fn bench_compute_cells_and_kzg_proofs(c: &mut Criterion) { let trusted_setup = TrustedSetup::default(); let blob = dummy_blob(); - for num_threads in THREAD_COUNTS { - let ctx = DASContext::with_threads( + for num_threads in [1] { + let ctx = DASContext::new( &trusted_setup, - num_threads, + // num_threads, + // bls12_381::fixed_base_msm::UsePrecomp::No, bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, ); c.bench_function( @@ -59,42 +60,42 @@ pub fn bench_compute_cells_and_kzg_proofs(c: &mut Criterion) { } } -pub fn bench_recover_cells_and_compute_kzg_proofs(c: &mut Criterion) { - let trusted_setup = TrustedSetup::default(); - - let (_, (cells, 
_)) = dummy_commitment_cells_and_proofs(); - let cell_indices: Vec = (0..cells.len()).map(|x| x as u64).collect(); - - // Worse case is when half of the cells are missing - let half_cell_indices = &cell_indices[..CELLS_PER_EXT_BLOB / 2]; - let half_cells = &cells[..CELLS_PER_EXT_BLOB / 2]; - let half_cells = half_cells - .into_iter() - .map(|cell| cell.as_ref()) - .collect::>(); - - for num_threads in THREAD_COUNTS { - let ctx = DASContext::with_threads( - &trusted_setup, - num_threads, - bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, - ); - c.bench_function( - &format!( - "worse-case recover_cells_and_kzg_proofs - NUM_THREADS: {:?}", - num_threads - ), - |b| { - b.iter(|| { - ctx.recover_cells_and_kzg_proofs( - half_cell_indices.to_vec(), - half_cells.to_vec(), - ) - }) - }, - ); - } -} +// pub fn bench_recover_cells_and_compute_kzg_proofs(c: &mut Criterion) { +// let trusted_setup = TrustedSetup::default(); + +// let (_, (cells, _)) = dummy_commitment_cells_and_proofs(); +// let cell_indices: Vec = (0..cells.len()).map(|x| x as u64).collect(); + +// // Worse case is when half of the cells are missing +// let half_cell_indices = &cell_indices[..CELLS_PER_EXT_BLOB / 2]; +// let half_cells = &cells[..CELLS_PER_EXT_BLOB / 2]; +// let half_cells = half_cells +// .into_iter() +// .map(|cell| cell.as_ref()) +// .collect::>(); + +// for num_threads in THREAD_COUNTS { +// let ctx = DASContext::with_threads( +// &trusted_setup, +// num_threads, +// bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, +// ); +// c.bench_function( +// &format!( +// "worse-case recover_cells_and_kzg_proofs - NUM_THREADS: {:?}", +// num_threads +// ), +// |b| { +// b.iter(|| { +// ctx.recover_cells_and_kzg_proofs( +// half_cell_indices.to_vec(), +// half_cells.to_vec(), +// ) +// }) +// }, +// ); +// } +// } pub fn bench_verify_cell_kzg_proof_batch(c: &mut Criterion) { let trusted_setup = TrustedSetup::default(); @@ -106,10 +107,9 @@ pub fn bench_verify_cell_kzg_proof_batch(c: &mut Criterion) { let cell_refs: Vec = cells.iter().map(|cell| cell.as_ref()).collect(); let proof_refs: Vec = proofs.iter().map(|proof| proof).collect(); - for num_threads in THREAD_COUNTS { - let ctx = DASContext::with_threads( + for num_threads in [1] { + let ctx = DASContext::new( &trusted_setup, - num_threads, bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, ); c.bench_function( @@ -131,25 +131,25 @@ pub fn bench_verify_cell_kzg_proof_batch(c: &mut Criterion) { } } -pub fn bench_init_context(c: &mut Criterion) { - const NUM_THREADS: ThreadCount = ThreadCount::Single; - c.bench_function(&format!("Initialize context"), |b| { - b.iter(|| { - let trusted_setup = TrustedSetup::default(); - DASContext::with_threads( - &trusted_setup, - NUM_THREADS, - bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, - ) - }) - }); -} +// pub fn bench_init_context(c: &mut Criterion) { +// const NUM_THREADS: ThreadCount = ThreadCount::Single; +// c.bench_function(&format!("Initialize context"), |b| { +// b.iter(|| { +// let trusted_setup = TrustedSetup::default(); +// DASContext::with_threads( +// &trusted_setup, +// NUM_THREADS, +// bls12_381::fixed_base_msm::UsePrecomp::Yes { width: 8 }, +// ) +// }) +// }); +// } criterion_group!( benches, - bench_init_context, + // bench_init_context, bench_compute_cells_and_kzg_proofs, - bench_recover_cells_and_compute_kzg_proofs, - bench_verify_cell_kzg_proof_batch + // bench_recover_cells_and_compute_kzg_proofs, + // bench_verify_cell_kzg_proof_batch ); criterion_main!(benches);
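As an aside on the scalar-multiplication code in this patch: both the Tsaur-Chou and Seo-Kim paths finish by folding per-window sums together with a "double omega times, then add the next window" loop, which is Horner's rule in base 2^omega. A minimal, dependency-free sketch of that identity, with u128 arithmetic standing in for group doublings and additions (the function and test names here are illustrative, not part of the library):

/// Computes sum_t 2^(omega * t) * windows[t] by scanning the windows from most
/// significant to least significant, doubling `omega` times before each
/// addition. This mirrors the `for window in windows.into_iter().rev() { .. }`
/// recombination loop in the windowed multiplication code.
fn recombine_windows(windows: &[u128], omega: u32) -> u128 {
    let mut result = 0u128;
    for &window in windows.iter().rev() {
        for _ in 0..omega {
            result *= 2; // the curve code does `result = result.double()` here
        }
        result += window;
    }
    result
}

#[test]
fn recombine_windows_matches_direct_sum() {
    let omega = 4u32;
    let windows = [3u128, 0, 7, 1, 5];

    // Direct evaluation of sum_t 2^(omega * t) * windows[t].
    let direct: u128 = windows
        .iter()
        .enumerate()
        .map(|(t, &w)| w << (omega * t as u32))
        .sum();

    assert_eq!(recombine_windows(&windows, omega), direct);
}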