feat: Replace blst msm method with a Rust native method (#273)
* add rust implementation

* rename blst msm struct and replace it with rust one

* add benchmarks for blst and rust version

* fix: comment

* add comment on where code was taken from
kevaundray authored Sep 24, 2024
1 parent c4c1fb8 commit b4ef4af
Showing 5 changed files with 261 additions and 13 deletions.
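For orientation, here is a minimal sketch (not part of the diff) of the two fixed-base MSM APIs that the updated benchmark below exercises. The crate, module, and type names are taken from the changed files; the setup mirrors the tests, and the snippet assumes the crate's `rand` dependency and the root re-exports of `G1Projective` and `Scalar`.

use crate_crypto_internal_eth_kzg_bls12_381::{
    ff::Field,
    fixed_base_msm::FixedBaseMSMPrecompBLST,
    fixed_base_msm_window::FixedBaseMSMPrecompWindow,
    group::Group,
    G1Projective, Scalar,
};

fn compare_msm_backends() {
    // 64 random points and scalars, as in the tests and benchmarks below.
    let generators: Vec<_> = (0..64)
        .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
        .collect();
    let scalars: Vec<_> = (0..64)
        .map(|_| Scalar::random(&mut rand::thread_rng()))
        .collect();

    // Old path: blst-backed precomputation (owns the points, takes scalars by value).
    let blst_msm = FixedBaseMSMPrecompBLST::new(generators.clone(), 8);
    let expected = blst_msm.msm(scalars.clone());

    // New path: pure-Rust window precomputation (borrows both inputs).
    let window_msm = FixedBaseMSMPrecompWindow::new(&generators, 8);
    let got = window_msm.msm(&scalars);

    assert_eq!(expected, got);
}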
15 changes: 11 additions & 4 deletions cryptography/bls12_381/benches/benchmark.rs
@@ -1,7 +1,8 @@
use crate_crypto_internal_eth_kzg_bls12_381::{
batch_inversion,
ff::Field,
fixed_base_msm::{FixedBaseMSM, UsePrecomp},
fixed_base_msm::FixedBaseMSMPrecompBLST,
fixed_base_msm_window::FixedBaseMSMPrecompWindow,
g1_batch_normalize, g2_batch_normalize,
group::Group,
lincomb::{g1_lincomb, g1_lincomb_unsafe, g2_lincomb, g2_lincomb_unsafe},
@@ -28,12 +29,18 @@ pub fn fixed_base_msm(c: &mut Criterion) {
.into_iter()
.map(|p| p.into())
.collect();
let fbm = FixedBaseMSM::new(generators, UsePrecomp::Yes { width: 8 });
let scalars: Vec<_> = random_scalars(length);

c.bench_function("bls12_381 fixed_base_msm length=64 width=8", |b| {
let fbm = FixedBaseMSMPrecompBLST::new(generators.clone(), 8);
let scalars: Vec<_> = random_scalars(length);
c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (blst)", |b| {
b.iter(|| fbm.msm(scalars.clone()))
});

let fbm = FixedBaseMSMPrecompWindow::new(&generators, 8);
let scalars: Vec<_> = random_scalars(length);
c.bench_function("bls12_381 fixed_base_msm length=64 width=8 (rust)", |b| {
b.iter(|| fbm.msm(&scalars))
});
}

pub fn bench_msm(c: &mut Criterion) {
96 changes: 96 additions & 0 deletions cryptography/bls12_381/src/booth_encoding.rs
@@ -0,0 +1,96 @@
use std::ops::Neg;

// Code was taken from: https://github.com/privacy-scaling-explorations/halo2curves/blob/b753a832e92d5c86c5c997327a9cf9de86a18851/src/msm.rs#L13
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1`
// * each window overlaps by 1 bit
// * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
// So we can reduce the bucket size without preprocessing scalars
// and remembering them as in classic signed digit encoding

let skip_bits = (window_index * window_size).saturating_sub(1);
let skip_bytes = skip_bits / 8;

// fill into a u32
let mut v: [u8; 4] = [0; 4];
for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
*dst = *src;
}
let mut tmp = u32::from_le_bytes(v);

// pad with one 0 if slicing the least significant window
if window_index == 0 {
tmp <<= 1;
}

// remove further bits
tmp >>= skip_bits - (skip_bytes * 8);
// apply the booth window
tmp &= (1 << (window_size + 1)) - 1;

let sign = tmp & (1 << window_size) == 0;

// div ceil by 2
tmp = (tmp + 1) >> 1;

// find the booth action index
if sign {
tmp as i32
} else {
((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg()
}
}

#[cfg(test)]
mod tests {
use std::ops::Neg;

use super::get_booth_index;
use crate::G1Point;
use blstrs::{G1Projective, Scalar};
use ff::{Field, PrimeField};

#[test]
fn smoke_scalar_mul() {
use group::prime::PrimeCurveAffine;
let gen = G1Point::generator();
let s = -Scalar::ONE;

let res = gen * s;

let got = mul(&s, &gen, 4);

assert_eq!(G1Point::from(res), got)
}

fn mul(scalar: &Scalar, point: &G1Point, window: usize) -> G1Point {
let u = scalar.to_bytes_le();
let n = Scalar::NUM_BITS as usize / window + 1;

let table = (0..=1 << (window - 1))
.map(|i| point * Scalar::from(i as u64))
.collect::<Vec<_>>();

let mut acc: G1Projective = G1Point::default().into();
for i in (0..n).rev() {
for _ in 0..window {
acc = acc + acc;
}

let idx = get_booth_index(i as usize, window, u.as_ref());

if idx.is_negative() {
acc += table[idx.unsigned_abs() as usize].neg();
}
if idx.is_positive() {
acc += table[idx.unsigned_abs() as usize];
}
}

acc.into()
}
}
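To make the indexing rule in booth_encoding.rs concrete: with a 4-bit window the value 255 decomposes into the signed digits [-1, 0, 1, 0, ...], i.e. 255 = -1·16^0 + 0·16^1 + 1·16^2. A small sketch (not in this commit) that could sit alongside the tests above, since `get_booth_index` is crate-private:

#[test]
fn booth_digits_recompose_small_value() {
    // Hypothetical sanity check: recombine the booth digits of 255 with a
    // 4-bit window and confirm they evaluate back to 255.
    let window = 4;
    let value: u64 = 255;
    let bytes = value.to_le_bytes();

    // Horner evaluation over windows, most significant first.
    let mut acc: i64 = 0;
    for w in (0..4).rev() {
        acc <<= window; // scale by 2^window
        acc += get_booth_index(w, window, &bytes) as i64; // add the signed digit
    }
    assert_eq!(acc, value as i64);
}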
18 changes: 9 additions & 9 deletions cryptography/bls12_381/src/fixed_base_msm.rs
@@ -1,11 +1,11 @@
use crate::{G1Projective, Scalar};
use crate::{fixed_base_msm_window::FixedBaseMSMPrecompWindow, G1Projective, Scalar};
use blstrs::{Fp, G1Affine};

/// FixedBaseMSMPrecomp computes a multi scalar multiplication using pre-computations.
///
/// It uses batch addition to amortize the cost of adding multiple points together.
#[derive(Debug)]
pub struct FixedBaseMSMPrecomp {
pub struct FixedBaseMSMPrecompBLST {
table: Vec<blst::blst_p1_affine>,
wbits: usize,
num_points: usize,
@@ -27,23 +27,23 @@ pub enum UsePrecomp {
/// of memory.
#[derive(Debug)]
pub enum FixedBaseMSM {
Precomp(FixedBaseMSMPrecomp),
Precomp(FixedBaseMSMPrecompWindow),
NoPrecomp(Vec<G1Affine>),
}

impl FixedBaseMSM {
pub fn new(generators: Vec<G1Affine>, use_precomp: UsePrecomp) -> Self {
match use_precomp {
UsePrecomp::Yes { width } => {
FixedBaseMSM::Precomp(FixedBaseMSMPrecomp::new(generators, width))
FixedBaseMSM::Precomp(FixedBaseMSMPrecompWindow::new(&generators, width))
}
UsePrecomp::No => FixedBaseMSM::NoPrecomp(generators),
}
}

pub fn msm(&self, scalars: Vec<Scalar>) -> G1Projective {
match self {
FixedBaseMSM::Precomp(precomp) => precomp.msm(scalars),
FixedBaseMSM::Precomp(precomp) => precomp.msm(&scalars),
FixedBaseMSM::NoPrecomp(generators) => {
use crate::lincomb::g1_lincomb;
g1_lincomb(generators, &scalars)
@@ -53,7 +53,7 @@ impl FixedBaseMSM {
}
}

impl FixedBaseMSMPrecomp {
impl FixedBaseMSMPrecompBLST {
pub fn new(generators_affine: Vec<G1Affine>, wbits: usize) -> Self {
let num_points = generators_affine.len();
let table_size_bytes =
@@ -74,7 +74,7 @@ impl FixedBaseMSMPrecomp {

let scratch_space_size = unsafe { blst::blst_p1s_mult_wbits_scratch_sizeof(num_points) };

FixedBaseMSMPrecomp {
FixedBaseMSMPrecompBLST {
table,
wbits,
num_points,
@@ -120,7 +120,7 @@ impl FixedBaseMSMPrecomp {

#[cfg(test)]
mod tests {
use super::{FixedBaseMSMPrecomp, UsePrecomp};
use super::{FixedBaseMSMPrecompBLST, UsePrecomp};
use crate::{fixed_base_msm::FixedBaseMSM, G1Projective, Scalar};
use ff::Field;
use group::Group;
@@ -158,7 +158,7 @@ mod tests {
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let fbm = FixedBaseMSMPrecomp::new(generators, 8);
let fbm = FixedBaseMSMPrecompBLST::new(generators, 8);
for val in fbm.table.into_iter() {
let is_inf =
unsafe { blst::blst_p1_affine_is_inf(&val as *const blst::blst_p1_affine) };
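The higher-level entry point above is unchanged for callers: `FixedBaseMSM::new` still takes a `UsePrecomp` flag, but the `Yes` arm now builds the Rust window table instead of the blst one. A minimal sketch of that call path (illustrative only; names as in the file above, random setup as in the tests):

use crate_crypto_internal_eth_kzg_bls12_381::{
    ff::Field,
    fixed_base_msm::{FixedBaseMSM, UsePrecomp},
    group::Group,
    G1Projective, Scalar,
};

fn msm_via_enum() -> G1Projective {
    let generators: Vec<_> = (0..64)
        .map(|_| G1Projective::random(&mut rand::thread_rng()).into())
        .collect();
    let scalars: Vec<Scalar> = (0..64)
        .map(|_| Scalar::random(&mut rand::thread_rng()))
        .collect();

    // `width: 8` selects an 8-bit window; internally this now constructs a
    // FixedBaseMSMPrecompWindow rather than the blst-backed table.
    let msm = FixedBaseMSM::new(generators, UsePrecomp::Yes { width: 8 });
    msm.msm(scalars)
}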
143 changes: 143 additions & 0 deletions cryptography/bls12_381/src/fixed_base_msm_window.rs
@@ -0,0 +1,143 @@
use crate::{
batch_add::multi_batch_addition_binary_tree_stride, booth_encoding::get_booth_index,
g1_batch_normalize, G1Projective, Scalar,
};
use blstrs::G1Affine;
use ff::PrimeField;
use group::Group;

// Note: This is the same strategy that blst uses
#[derive(Debug)]
pub struct FixedBaseMSMPrecompWindow {
table: Vec<Vec<G1Affine>>,
wbits: usize,
}

impl FixedBaseMSMPrecompWindow {
pub fn new(points: &[G1Affine], wbits: usize) -> Self {
// For every point `P`, wbits indicates that we should compute
// 1 * P, ..., 2^{wbits - 1} * P
//
// The total amount of memory is roughly numPoints * 2^{wbits - 1} points,
// where each affine point takes roughly 96 bytes (two 48-byte field elements).
//
let precomputed_points: Vec<_> = points
.into_iter()
.map(|point| Self::precompute_points(wbits, *point))
.collect();

Self {
table: precomputed_points,
wbits,
}
}
// Given a point P, we precompute 1 * P, ..., 2^{wbits - 1} * P
fn precompute_points(wbits: usize, point: G1Affine) -> Vec<G1Affine> {
let mut lookup_table = Vec::with_capacity(1 << (wbits - 1));

// Convert to projective for faster operations
let mut current = G1Projective::from(point);

// Compute and store multiples
for _ in 0..(1 << (wbits - 1)) {
lookup_table.push(current);
current += point;
}

g1_batch_normalize(&lookup_table)
}

pub fn msm(&self, scalars: &[Scalar]) -> G1Projective {
let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect();
let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1;

let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows];

for window_idx in 0..number_of_windows {
for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() {
let sub_table = &self.table[scalar_idx];
let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref());

if point_idx == 0 {
continue;
}
let sign = point_idx.is_positive();
let point_idx = point_idx.unsigned_abs() as usize - 1;
let mut point = sub_table[point_idx];
if !sign {
point = -point;
}

windows_of_points[window_idx].push(point);
}
}

let accumulated_points = multi_batch_addition_binary_tree_stride(windows_of_points);

// Now accumulate the windows by doubling wbits times
let mut result: G1Projective = *accumulated_points.last().unwrap();
for point in accumulated_points.into_iter().rev().skip(1) {
// Double the result 'wbits' times
for _ in 0..self.wbits {
result = result.double();
}
// Add the accumulated point for this window
result += point;
}

result
}
}

#[cfg(test)]
mod tests {
use super::*;
use ff::Field;
use group::prime::PrimeCurveAffine;

#[test]
fn precomp_lookup_table() {
use group::Group;
let lookup_table = FixedBaseMSMPrecompWindow::precompute_points(7, G1Affine::generator());

for i in 1..lookup_table.len() {
let expected = G1Projective::generator() * Scalar::from((i + 1) as u64);
assert_eq!(lookup_table[i], expected.into(),)
}
}

#[test]
fn msm_blst_precomp() {
let length = 64;
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let scalars: Vec<_> = (0..length)
.map(|_| Scalar::random(&mut rand::thread_rng()))
.collect();

let res = crate::lincomb::g1_lincomb(&generators, &scalars)
.expect("number of generators and number of scalars is equal");

let fbm = FixedBaseMSMPrecompWindow::new(&generators, 7);
let result = fbm.msm(&scalars);

assert_eq!(res, result);
}

#[test]
fn bench_window_sizes_msm() {
let length = 64;
let generators: Vec<_> = (0..length)
.map(|_| G1Projective::random(&mut rand::thread_rng()).into())
.collect();
let scalars: Vec<_> = (0..length)
.map(|_| Scalar::random(&mut rand::thread_rng()))
.collect();

for i in 2..=14 {
let fbm = FixedBaseMSMPrecompWindow::new(&generators, i);
fbm.msm(&scalars);
}
}
}
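The final loop in `msm` above is a Horner-style evaluation over the windows: the running result is doubled `wbits` times (multiplying it by 2^{wbits}) before the next window's accumulated point is added. An integer-only sketch of the same pattern (illustrative, not code from the commit):

// Stand-in for the doubling loop in `msm`, with u64 window sums in place of
// curve points. `window_sums[0]` is the least significant window.
fn horner_accumulate(window_sums: &[u64], wbits: usize) -> u64 {
    let mut result = *window_sums.last().unwrap();
    for &sum in window_sums.iter().rev().skip(1) {
        for _ in 0..wbits {
            result += result; // "double" wbits times, i.e. multiply by 2^wbits
        }
        result += sum;
    }
    result
}

// Example: horner_accumulate(&[3, 1, 2], 4) == 3 + 1 * 16 + 2 * 256 == 531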
2 changes: 2 additions & 0 deletions cryptography/bls12_381/src/lib.rs
@@ -1,6 +1,8 @@
pub mod batch_add;
pub mod batch_inversion;
mod booth_encoding;
pub mod fixed_base_msm;
pub mod fixed_base_msm_window;
pub mod lincomb;

// Re-exporting the blstrs crate