From b4ef4afa16ebba2b1e071eda1c6f70542d7b7989 Mon Sep 17 00:00:00 2001
From: kevaundray
Date: Tue, 24 Sep 2024 21:12:36 +0100
Subject: [PATCH] feat: Replace blst msm method with a Rust native method (#273)

* add rust implementation
* rename blst msm struct and replace it with rust one
* add benchmarks for blst and rust version
* fix: comment
* add comment on where code was taken from
---
 cryptography/bls12_381/benches/benchmark.rs  |  15 +-
 cryptography/bls12_381/src/booth_encoding.rs |  96 ++++++++++++
 cryptography/bls12_381/src/fixed_base_msm.rs |  18 +--
 .../bls12_381/src/fixed_base_msm_window.rs   | 143 ++++++++++++++++++
 cryptography/bls12_381/src/lib.rs            |   2 +
 5 files changed, 261 insertions(+), 13 deletions(-)
 create mode 100644 cryptography/bls12_381/src/booth_encoding.rs
 create mode 100644 cryptography/bls12_381/src/fixed_base_msm_window.rs

diff --git a/cryptography/bls12_381/benches/benchmark.rs b/cryptography/bls12_381/benches/benchmark.rs
index 7db62c3b..80974a20 100644
--- a/cryptography/bls12_381/benches/benchmark.rs
+++ b/cryptography/bls12_381/benches/benchmark.rs
@@ -1,7 +1,8 @@
 use crate_crypto_internal_eth_kzg_bls12_381::{
     batch_inversion,
     ff::Field,
-    fixed_base_msm::{FixedBaseMSM, UsePrecomp},
+    fixed_base_msm::FixedBaseMSMPrecompBLST,
+    fixed_base_msm_window::FixedBaseMSMPrecompWindow,
     g1_batch_normalize, g2_batch_normalize,
     group::Group,
     lincomb::{g1_lincomb, g1_lincomb_unsafe, g2_lincomb, g2_lincomb_unsafe},
@@ -28,12 +29,18 @@ pub fn fixed_base_msm(c: &mut Criterion) {
         .into_iter()
         .map(|p| p.into())
         .collect();
-    let fbm = FixedBaseMSM::new(generators, UsePrecomp::Yes { width: 8 });
-    let scalars: Vec<_> = random_scalars(length);
-    c.bench_function("bls12_381 fixed_base_msm length=64 width=8", |b| {
+    let fbm = FixedBaseMSMPrecompBLST::new(generators.clone(), 8);
+    let scalars: Vec<_> = random_scalars(length);
+    c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (blst)", |b| {
         b.iter(|| fbm.msm(scalars.clone()))
     });
+
+    let fbm = FixedBaseMSMPrecompWindow::new(&generators, 8);
+    let scalars: Vec<_> = random_scalars(length);
+    c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (rust)", |b| {
+        b.iter(|| fbm.msm(&scalars))
+    });
 }
 
 pub fn bench_msm(c: &mut Criterion) {
diff --git a/cryptography/bls12_381/src/booth_encoding.rs b/cryptography/bls12_381/src/booth_encoding.rs
new file mode 100644
index 00000000..ff459e18
--- /dev/null
+++ b/cryptography/bls12_381/src/booth_encoding.rs
@@ -0,0 +1,96 @@
+use std::ops::Neg;
+
+// Code was taken from: https://github.com/privacy-scaling-explorations/halo2curves/blob/b753a832e92d5c86c5c997327a9cf9de86a18851/src/msm.rs#L13
+pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
+    // Booth encoding:
+    // * step by `window` size
+    // * slice by a size of `window + 1`
+    // * each window overlaps the next by 1 bit
+    // * append a zero bit to the least significant end
+    // Indexing rule for an example window size of 3, where we slice by 4 bits:
+    // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
+    // So we can reduce the bucket size without preprocessing the scalars
+    // and remembering them, as in classic signed-digit encoding.
+
+    let skip_bits = (window_index * window_size).saturating_sub(1);
+    let skip_bytes = skip_bits / 8;
+
+    // fill into a u32
+    let mut v: [u8; 4] = [0; 4];
+    for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
+        *dst = *src
+    }
+    let mut tmp = u32::from_le_bytes(v);
+
+    // pad with one 0 if slicing the least significant window
+    if window_index == 0 {
+        tmp <<= 1;
+    }
+
+    // remove further bits
+    tmp >>= skip_bits - (skip_bytes * 8);
+    // apply the Booth window
+    tmp &= (1 << (window_size + 1)) - 1;
+
+    let sign = tmp & (1 << window_size) == 0;
+
+    // div ceil by 2
+    tmp = (tmp + 1) >> 1;
+
+    // find the Booth action index
+    if sign {
+        tmp as i32
+    } else {
+        ((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::ops::Neg;
+
+    use super::get_booth_index;
+    use crate::G1Point;
+    use blstrs::{G1Projective, Scalar};
+    use ff::{Field, PrimeField};
+
+    #[test]
+    fn smoke_scalar_mul() {
+        use group::prime::PrimeCurveAffine;
+        let gen = G1Point::generator();
+        let s = -Scalar::ONE;
+
+        let res = gen * s;
+
+        let got = mul(&s, &gen, 4);
+
+        assert_eq!(G1Point::from(res), got)
+    }
+
+    fn mul(scalar: &Scalar, point: &G1Point, window: usize) -> G1Point {
+        let u = scalar.to_bytes_le();
+        let n = Scalar::NUM_BITS as usize / window + 1;
+
+        let table = (0..=1 << (window - 1))
+            .map(|i| point * Scalar::from(i as u64))
+            .collect::<Vec<_>>();
+
+        let mut acc: G1Projective = G1Point::default().into();
+        for i in (0..n).rev() {
+            for _ in 0..window {
+                acc = acc + acc;
+            }
+
+            let idx = get_booth_index(i as usize, window, u.as_ref());
+
+            if idx.is_negative() {
+                acc += table[idx.unsigned_abs() as usize].neg();
+            }
+            if idx.is_positive() {
+                acc += table[idx.unsigned_abs() as usize];
+            }
+        }
+
+        acc.into()
+    }
+}
diff --git a/cryptography/bls12_381/src/fixed_base_msm.rs b/cryptography/bls12_381/src/fixed_base_msm.rs
index 491ffc01..705f3556 100644
--- a/cryptography/bls12_381/src/fixed_base_msm.rs
+++ b/cryptography/bls12_381/src/fixed_base_msm.rs
@@ -1,11 +1,11 @@
-use crate::{G1Projective, Scalar};
+use crate::{fixed_base_msm_window::FixedBaseMSMPrecompWindow, G1Projective, Scalar};
 use blstrs::{Fp, G1Affine};
 
 /// FixedBaseMSMPrecomp computes a multi scalar multiplication using pre-computations.
 ///
 /// It uses batch addition to amortize the cost of adding multiple points together.
 #[derive(Debug)]
-pub struct FixedBaseMSMPrecomp {
+pub struct FixedBaseMSMPrecompBLST {
     table: Vec<blst::blst_p1_affine>,
     wbits: usize,
     num_points: usize,
@@ -27,7 +27,7 @@ pub enum UsePrecomp {
 /// of memory.
 #[derive(Debug)]
 pub enum FixedBaseMSM {
-    Precomp(FixedBaseMSMPrecomp),
+    Precomp(FixedBaseMSMPrecompWindow),
     NoPrecomp(Vec<G1Affine>),
 }
 
@@ -35,7 +35,7 @@ impl FixedBaseMSM {
     pub fn new(generators: Vec<G1Affine>, use_precomp: UsePrecomp) -> Self {
         match use_precomp {
             UsePrecomp::Yes { width } => {
-                FixedBaseMSM::Precomp(FixedBaseMSMPrecomp::new(generators, width))
+                FixedBaseMSM::Precomp(FixedBaseMSMPrecompWindow::new(&generators, width))
             }
             UsePrecomp::No => FixedBaseMSM::NoPrecomp(generators),
         }
@@ -43,7 +43,7 @@ impl FixedBaseMSM {
 
     pub fn msm(&self, scalars: Vec<Scalar>) -> G1Projective {
         match self {
-            FixedBaseMSM::Precomp(precomp) => precomp.msm(scalars),
+            FixedBaseMSM::Precomp(precomp) => precomp.msm(&scalars),
             FixedBaseMSM::NoPrecomp(generators) => {
                 use crate::lincomb::g1_lincomb;
                 g1_lincomb(generators, &scalars)
@@ -53,7 +53,7 @@ impl FixedBaseMSM {
     }
 }
 
-impl FixedBaseMSMPrecomp {
+impl FixedBaseMSMPrecompBLST {
     pub fn new(generators_affine: Vec<G1Affine>, wbits: usize) -> Self {
         let num_points = generators_affine.len();
         let table_size_bytes =
@@ -74,7 +74,7 @@ impl FixedBaseMSMPrecomp {
         let scratch_space_size = unsafe { blst::blst_p1s_mult_wbits_scratch_sizeof(num_points) };
 
-        FixedBaseMSMPrecomp {
+        FixedBaseMSMPrecompBLST {
             table,
             wbits,
             num_points,
@@ -120,7 +120,7 @@ impl FixedBaseMSMPrecomp {
 
 #[cfg(test)]
 mod tests {
-    use super::{FixedBaseMSMPrecomp, UsePrecomp};
+    use super::{FixedBaseMSMPrecompBLST, UsePrecomp};
     use crate::{fixed_base_msm::FixedBaseMSM, G1Projective, Scalar};
     use ff::Field;
     use group::Group;
@@ -158,7 +158,7 @@ mod tests {
         let generators: Vec<_> = (0..length)
             .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
             .collect();
-        let fbm = FixedBaseMSMPrecomp::new(generators, 8);
+        let fbm = FixedBaseMSMPrecompBLST::new(generators, 8);
         for val in fbm.table.into_iter() {
             let is_inf =
                 unsafe { blst::blst_p1_affine_is_inf(&val as *const blst::blst_p1_affine) };
diff --git a/cryptography/bls12_381/src/fixed_base_msm_window.rs b/cryptography/bls12_381/src/fixed_base_msm_window.rs
new file mode 100644
index 00000000..980ef96d
--- /dev/null
+++ b/cryptography/bls12_381/src/fixed_base_msm_window.rs
@@ -0,0 +1,143 @@
+use crate::{
+    batch_add::multi_batch_addition_binary_tree_stride, booth_encoding::get_booth_index,
+    g1_batch_normalize, G1Projective, Scalar,
+};
+use blstrs::G1Affine;
+use ff::PrimeField;
+use group::Group;
+
+// Note: This is the same strategy that blst uses
+#[derive(Debug)]
+pub struct FixedBaseMSMPrecompWindow {
+    table: Vec<Vec<G1Affine>>,
+    wbits: usize,
+}
+
+impl FixedBaseMSMPrecompWindow {
+    pub fn new(points: &[G1Affine], wbits: usize) -> Self {
+        // For every point `P`, `wbits` indicates that we should compute
+        // 1 * P, ..., 2^{wbits - 1} * P
+        //
+        // The total amount of memory is roughly numPoints * 2^{wbits - 1} points,
+        // where each point is 64 bytes.
+        //
+        let precomputed_points: Vec<_> = points
+            .iter()
+            .map(|point| Self::precompute_points(wbits, *point))
+            .collect();
+
+        Self {
+            table: precomputed_points,
+            wbits,
+        }
+    }
+    // Given a point, we precompute 1 * P, ..., 2^{wbits - 1} * P
+    fn precompute_points(wbits: usize, point: G1Affine) -> Vec<G1Affine> {
+        let mut lookup_table = Vec::with_capacity(1 << (wbits - 1));
+
+        // Convert to projective for faster operations
+        let mut current = G1Projective::from(point);
+
+        // Compute and store multiples
+        for _ in 0..(1 << (wbits - 1)) {
+            lookup_table.push(current);
+            current += point;
+        }
+
+        g1_batch_normalize(&lookup_table)
+    }
+
+    pub fn msm(&self, scalars: &[Scalar]) -> G1Projective {
+        let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect();
+        let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1;
+
+        let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows];
+
+        for window_idx in 0..number_of_windows {
+            for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() {
+                let sub_table = &self.table[scalar_idx];
+                let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref());
+
+                if point_idx == 0 {
+                    continue;
+                }
+                let sign = point_idx.is_positive();
+                let point_idx = point_idx.unsigned_abs() as usize - 1;
+                let mut point = sub_table[point_idx];
+                if !sign {
+                    point = -point;
+                }
+
+                windows_of_points[window_idx].push(point);
+            }
+        }
+
+        let accumulated_points = multi_batch_addition_binary_tree_stride(windows_of_points);
+
+        // Now accumulate the windows by doubling wbits times
+        let mut result: G1Projective = *accumulated_points.last().unwrap();
+        for point in accumulated_points.into_iter().rev().skip(1) {
+            // Double the result 'wbits' times
+            for _ in 0..self.wbits {
+                result = result.double();
+            }
+            // Add the accumulated point for this window
+            result += point;
+        }
+
+        result
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ff::Field;
+    use group::prime::PrimeCurveAffine;
+
+    #[test]
+    fn precomp_lookup_table() {
+        use group::Group;
+        let lookup_table = FixedBaseMSMPrecompWindow::precompute_points(7, G1Affine::generator());
+
+        for i in 1..lookup_table.len() {
+            let expected = G1Projective::generator() * Scalar::from((i + 1) as u64);
+            assert_eq!(lookup_table[i], expected.into())
+        }
+    }
+
+    #[test]
+    fn msm_blst_precomp() {
+        let length = 64;
+        let generators: Vec<_> = (0..length)
+            .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
+            .collect();
+        let scalars: Vec<_> = (0..length)
+            .map(|_| Scalar::random(&mut rand::thread_rng()))
+            .collect();
+
+        let res = crate::lincomb::g1_lincomb(&generators, &scalars)
+            .expect("number of generators and number of scalars is equal");
+
+        let fbm = FixedBaseMSMPrecompWindow::new(&generators, 7);
+        let result = fbm.msm(&scalars);
+
+        assert_eq!(res, result);
+    }
+
+    #[test]
+    fn bench_window_sizes_msm() {
+        let length = 64;
+        let generators: Vec<_> = (0..length)
+            .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
+            .collect();
+        let scalars: Vec<_> = (0..length)
+            .map(|_| Scalar::random(&mut rand::thread_rng()))
+            .collect();
+
+        for i in 2..=14 {
+            let fbm = FixedBaseMSMPrecompWindow::new(&generators, i);
+            fbm.msm(&scalars);
+        }
+    }
+}
diff --git a/cryptography/bls12_381/src/lib.rs b/cryptography/bls12_381/src/lib.rs
index dab3df9d..471a1c78 100644
--- a/cryptography/bls12_381/src/lib.rs
+++ b/cryptography/bls12_381/src/lib.rs
@@ -1,6 +1,8 @@
 pub mod batch_add;
 pub mod batch_inversion;
+mod booth_encoding;
 pub mod fixed_base_msm;
+pub mod fixed_base_msm_window;
 pub mod lincomb;
 
 // Re-exporting the blstrs crate
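
Note on the Booth encoding (illustrative, not part of the patch): each scalar is
cut into overlapping (wbits + 1)-bit slices that decode to signed digits in
[-2^{wbits-1}, 2^{wbits-1}], which is why the per-point table only needs the
multiples 1 * P through 2^{wbits-1} * P. The sketch below reconstructs a u64
from its Booth digits using a copy of `get_booth_index` from booth_encoding.rs
above; the 4-bit window and the test value are arbitrary choices.

use std::ops::Neg;

// Copied from booth_encoding.rs in this patch (comments trimmed).
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
    let skip_bits = (window_index * window_size).saturating_sub(1);
    let skip_bytes = skip_bits / 8;

    // fill into a u32
    let mut v: [u8; 4] = [0; 4];
    for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
        *dst = *src
    }
    let mut tmp = u32::from_le_bytes(v);

    // pad with one 0 if slicing the least significant window
    if window_index == 0 {
        tmp <<= 1;
    }

    tmp >>= skip_bits - (skip_bytes * 8);
    tmp &= (1 << (window_size + 1)) - 1;

    let sign = tmp & (1 << window_size) == 0;
    tmp = (tmp + 1) >> 1;

    if sign {
        tmp as i32
    } else {
        ((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
    }
}

fn main() {
    let scalar: u64 = 0xDEAD_BEEF_0BAD_F00D;
    let bytes = scalar.to_le_bytes();
    let window = 4;
    // One extra window catches the final carry, mirroring
    // `Scalar::NUM_BITS as usize / window + 1` in the patch.
    let num_windows = 64 / window + 1;

    // Horner-style reconstruction: the signed digits must satisfy
    // sum_i digit_i * 2^(window * i) == scalar.
    let mut acc: i128 = 0;
    for i in (0..num_windows).rev() {
        let digit = get_booth_index(i, window, &bytes);
        acc = (acc << window) + i128::from(digit);
    }
    assert_eq!(acc, i128::from(scalar));
    println!("Booth digits reconstruct 0x{scalar:016x}");
}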
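
A quick back-of-the-envelope for the wbits trade-off that
fixed_base_msm_window.rs describes (again illustrative, not part of the patch):
larger windows mean fewer Booth windows per scalar but an exponentially larger
per-point table. The 64-point count below mirrors the benchmark's `length`; the
64-byte affine size and 255-bit scalar width are BLS12-381 constants.

fn main() {
    let num_points = 64; // matches the benchmark's `length`
    let scalar_bits = 255; // Scalar::NUM_BITS for BLS12-381

    for wbits in [4usize, 7, 8, 10] {
        // The table stores 1 * P ..= 2^(wbits - 1) * P for every point.
        let entries_per_point = 1usize << (wbits - 1);
        let table_kib = num_points * entries_per_point * 64 / 1024;
        let number_of_windows = scalar_bits / wbits + 1;
        println!("wbits={wbits:2}: {number_of_windows:2} windows, table ~{table_kib} KiB");
    }
}

With these numbers, wbits = 8 costs about 512 KiB of precomputation for 64
points and processes 32 windows per scalar.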