-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Replace blst msm method with a Rust native method (#273)
* add rust implementation * rename blst msm struct and replace it with rust one * add benchmarks for blst and rust version * fix: comment * add comment on where code was taken from
- Loading branch information
1 parent
c4c1fb8
commit b4ef4af
Showing
5 changed files
with
261 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
use std::ops::Neg; | ||
|
||
// Code was taken from: https://github.com/privacy-scaling-explorations/halo2curves/blob/b753a832e92d5c86c5c997327a9cf9de86a18851/src/msm.rs#L13 | ||
pub fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { | ||
// Booth encoding: | ||
// * step by `window` size | ||
// * slice by size of `window + 1`` | ||
// * each window overlap by 1 bit | ||
// * append a zero bit to the least significant end | ||
// Indexing rule for example window size 3 where we slice by 4 bits: | ||
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]`` | ||
// So we can reduce the bucket size without preprocessing scalars | ||
// and remembering them as in classic signed digit encoding | ||
|
||
let skip_bits = (window_index * window_size).saturating_sub(1); | ||
let skip_bytes = skip_bits / 8; | ||
|
||
// fill into a u32 | ||
let mut v: [u8; 4] = [0; 4]; | ||
for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) { | ||
*dst = *src | ||
} | ||
let mut tmp = u32::from_le_bytes(v); | ||
|
||
// pad with one 0 if slicing the least significant window | ||
if window_index == 0 { | ||
tmp <<= 1; | ||
} | ||
|
||
// remove further bits | ||
tmp >>= skip_bits - (skip_bytes * 8); | ||
// apply the booth window | ||
tmp &= (1 << (window_size + 1)) - 1; | ||
|
||
let sign = tmp & (1 << window_size) == 0; | ||
|
||
// div ceil by 2 | ||
tmp = (tmp + 1) >> 1; | ||
|
||
// find the booth action index | ||
if sign { | ||
tmp as i32 | ||
} else { | ||
((!(tmp - 1) & ((1 << window_size) - 1)) as i32).neg() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::ops::Neg; | ||
|
||
use super::get_booth_index; | ||
use crate::G1Point; | ||
use blstrs::{G1Projective, Scalar}; | ||
use ff::{Field, PrimeField}; | ||
|
||
#[test] | ||
fn smoke_scalar_mul() { | ||
use group::prime::PrimeCurveAffine; | ||
let gen = G1Point::generator(); | ||
let s = -Scalar::ONE; | ||
|
||
let res = gen * s; | ||
|
||
let got = mul(&s, &gen, 4); | ||
|
||
assert_eq!(G1Point::from(res), got) | ||
} | ||
|
||
fn mul(scalar: &Scalar, point: &G1Point, window: usize) -> G1Point { | ||
let u = scalar.to_bytes_le(); | ||
let n = Scalar::NUM_BITS as usize / window + 1; | ||
|
||
let table = (0..=1 << (window - 1)) | ||
.map(|i| point * Scalar::from(i as u64)) | ||
.collect::<Vec<_>>(); | ||
|
||
let mut acc: G1Projective = G1Point::default().into(); | ||
for i in (0..n).rev() { | ||
for _ in 0..window { | ||
acc = acc + acc; | ||
} | ||
|
||
let idx = get_booth_index(i as usize, window, u.as_ref()); | ||
|
||
if idx.is_negative() { | ||
acc += table[idx.unsigned_abs() as usize].neg(); | ||
} | ||
if idx.is_positive() { | ||
acc += table[idx.unsigned_abs() as usize]; | ||
} | ||
} | ||
|
||
acc.into() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
use crate::{ | ||
batch_add::multi_batch_addition_binary_tree_stride, booth_encoding::get_booth_index, | ||
g1_batch_normalize, G1Projective, Scalar, | ||
}; | ||
use blstrs::G1Affine; | ||
use ff::PrimeField; | ||
use group::Group; | ||
|
||
// Note: This is the same strategy that blst uses | ||
#[derive(Debug)] | ||
pub struct FixedBaseMSMPrecompWindow { | ||
table: Vec<Vec<G1Affine>>, | ||
wbits: usize, | ||
} | ||
|
||
impl FixedBaseMSMPrecompWindow { | ||
pub fn new(points: &[G1Affine], wbits: usize) -> Self { | ||
// For every point `P`, wbits indicates that we should compute | ||
// 1 * P, ..., (2^{wbits} - 1) * P | ||
// | ||
// The total amount of memory is roughly (numPoints * 2^{wbits} - 1) | ||
// where each point is 64 bytes. | ||
// | ||
let precomputed_points: Vec<_> = points | ||
.into_iter() | ||
.map(|point| Self::precompute_points(wbits, *point)) | ||
.collect(); | ||
|
||
Self { | ||
table: precomputed_points, | ||
wbits, | ||
} | ||
} | ||
// Given a point, we precompute P,..., (2^{w-1}-1) * P | ||
fn precompute_points(wbits: usize, point: G1Affine) -> Vec<G1Affine> { | ||
let mut lookup_table = Vec::with_capacity(1 << (wbits - 1)); | ||
|
||
// Convert to projective for faster operations | ||
let mut current = G1Projective::from(point); | ||
|
||
// Compute and store multiples | ||
for _ in 0..(1 << (wbits - 1)) { | ||
lookup_table.push(current); | ||
current += point; | ||
} | ||
|
||
g1_batch_normalize(&lookup_table) | ||
} | ||
|
||
pub fn msm(&self, scalars: &[Scalar]) -> G1Projective { | ||
let scalars_bytes: Vec<_> = scalars.iter().map(|a| a.to_bytes_le()).collect(); | ||
let number_of_windows = Scalar::NUM_BITS as usize / self.wbits + 1; | ||
|
||
let mut windows_of_points = vec![Vec::with_capacity(scalars.len()); number_of_windows]; | ||
|
||
for window_idx in 0..number_of_windows { | ||
for (scalar_idx, scalar_bytes) in scalars_bytes.iter().enumerate() { | ||
let sub_table = &self.table[scalar_idx]; | ||
let point_idx = get_booth_index(window_idx, self.wbits, scalar_bytes.as_ref()); | ||
|
||
if point_idx == 0 { | ||
continue; | ||
} | ||
let sign = point_idx.is_positive(); | ||
let point_idx = point_idx.unsigned_abs() as usize - 1; | ||
let mut point = sub_table[point_idx]; | ||
if !sign { | ||
point = -point; | ||
} | ||
|
||
windows_of_points[window_idx].push(point); | ||
} | ||
} | ||
|
||
let accumulated_points = multi_batch_addition_binary_tree_stride(windows_of_points); | ||
|
||
// Now accumulate the windows by doubling wbits times | ||
let mut result: G1Projective = *accumulated_points.last().unwrap(); | ||
for point in accumulated_points.into_iter().rev().skip(1) { | ||
// Double the result 'wbits' times | ||
for _ in 0..self.wbits { | ||
result = result.double(); | ||
} | ||
// Add the accumulated point for this window | ||
result += point; | ||
} | ||
|
||
result | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use ff::Field; | ||
use group::prime::PrimeCurveAffine; | ||
|
||
#[test] | ||
fn precomp_lookup_table() { | ||
use group::Group; | ||
let lookup_table = FixedBaseMSMPrecompWindow::precompute_points(7, G1Affine::generator()); | ||
|
||
for i in 1..lookup_table.len() { | ||
let expected = G1Projective::generator() * Scalar::from((i + 1) as u64); | ||
assert_eq!(lookup_table[i], expected.into(),) | ||
} | ||
} | ||
|
||
#[test] | ||
fn msm_blst_precomp() { | ||
let length = 64; | ||
let generators: Vec<_> = (0..length) | ||
.map(|_| G1Projective::random(&mut rand::thread_rng()).into()) | ||
.collect(); | ||
let scalars: Vec<_> = (0..length) | ||
.map(|_| Scalar::random(&mut rand::thread_rng())) | ||
.collect(); | ||
|
||
let res = crate::lincomb::g1_lincomb(&generators, &scalars) | ||
.expect("number of generators and number of scalars is equal"); | ||
|
||
let fbm = FixedBaseMSMPrecompWindow::new(&generators, 7); | ||
let result = fbm.msm(&scalars); | ||
|
||
assert_eq!(res, result); | ||
} | ||
|
||
#[test] | ||
fn bench_window_sizes_msm() { | ||
let length = 64; | ||
let generators: Vec<_> = (0..length) | ||
.map(|_| G1Projective::random(&mut rand::thread_rng()).into()) | ||
.collect(); | ||
let scalars: Vec<_> = (0..length) | ||
.map(|_| Scalar::random(&mut rand::thread_rng())) | ||
.collect(); | ||
|
||
for i in 2..=14 { | ||
let fbm = FixedBaseMSMPrecompWindow::new(&generators, i); | ||
fbm.msm(&scalars); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters