Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor to use const generics instead of GenericArray and typenum #125

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.51.0
1.56.1
32 changes: 11 additions & 21 deletions src/batch_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,32 @@ use crate::poseidon::SimplePoseidonBatchHasher;
use crate::proteus::gpu::ClBatchHasher;
#[cfg(feature = "futhark")]
use crate::triton::{cl, gpu::GpuBatchHasher};
use crate::{Arity, BatchHasher, Strength, DEFAULT_STRENGTH};
use crate::{BatchHasher, Strength, DEFAULT_STRENGTH};
use blstrs::Scalar as Fr;
use generic_array::GenericArray;
#[cfg(feature = "futhark")]
use rust_gpu_tools::opencl;
use rust_gpu_tools::Device;

#[cfg(feature = "futhark")]
use triton::FutharkContext;

pub enum Batcher<A>
where
A: Arity<Fr>,
{
Cpu(SimplePoseidonBatchHasher<A>),
pub enum Batcher<const ARITY: usize, const WIDTH: usize> {
Cpu(SimplePoseidonBatchHasher<ARITY, WIDTH>),
#[cfg(feature = "futhark")]
OpenCl(GpuBatchHasher<A>),
OpenCl(GpuBatchHasher<ARITY, WIDTH>),
#[cfg(any(feature = "cuda", feature = "opencl"))]
OpenCl(ClBatchHasher<A>),
OpenCl(ClBatchHasher<ARITY, WIDTH>),
}

impl<A> Batcher<A>
where
A: Arity<Fr>,
{
impl<const ARITY: usize, const WIDTH: usize> Batcher<ARITY, WIDTH> {
/// Create a new CPU batcher.
pub fn new_cpu(max_batch_size: usize) -> Self {
Self::with_strength_cpu(DEFAULT_STRENGTH, max_batch_size)
}

/// Create a new CPU batcher with a specified strength.
pub fn with_strength_cpu(strength: Strength, max_batch_size: usize) -> Self {
Self::Cpu(SimplePoseidonBatchHasher::<A>::new_with_strength(
Self::Cpu(SimplePoseidonBatchHasher::new_with_strength(
strength,
max_batch_size,
))
Expand All @@ -50,7 +43,7 @@ where
#[cfg(feature = "futhark")]
pub fn pick_gpu(max_batch_size: usize) -> Result<Self, Error> {
let futhark_context = cl::default_futhark_context()?;
Ok(Self::OpenCl(GpuBatchHasher::<A>::new_with_strength(
Ok(Self::OpenCl(GpuBatchHasher::new_with_strength(
futhark_context,
DEFAULT_STRENGTH,
max_batch_size,
Expand Down Expand Up @@ -103,19 +96,16 @@ where
strength: Strength,
max_batch_size: usize,
) -> Result<Self, Error> {
Ok(Self::OpenCl(ClBatchHasher::<A>::new_with_strength(
Ok(Self::OpenCl(ClBatchHasher::new_with_strength(
device,
strength,
max_batch_size,
)?))
}
}

impl<A> BatchHasher<A> for Batcher<A>
where
A: Arity<Fr>,
{
fn hash(&mut self, preimages: &[GenericArray<Fr, A>]) -> Result<Vec<Fr>, Error> {
impl<const ARITY: usize, const WIDTH: usize> BatchHasher<ARITY, WIDTH> for Batcher<ARITY, WIDTH> {
fn hash(&mut self, preimages: &[[Fr; ARITY]]) -> Result<Vec<Fr>, Error> {
match self {
Batcher::Cpu(batcher) => batcher.hash(preimages),
#[cfg(any(feature = "futhark", feature = "cuda", feature = "opencl"))]
Expand Down
137 changes: 65 additions & 72 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use std::ops::{AddAssign, MulAssign};

use crate::hash_type::HashType;
use crate::matrix::Matrix;
use crate::mds::SparseMatrix;
use crate::poseidon::{Arity, PoseidonConstants};
use crate::poseidon::PoseidonConstants;
use bellperson::gadgets::boolean::Boolean;
use bellperson::gadgets::num;
use bellperson::gadgets::num::AllocatedNum;
use bellperson::{ConstraintSystem, LinearCombination, SynthesisError};
use ff::{Field, PrimeField};
use std::marker::PhantomData;

use bellperson::{ConstraintSystem, SynthesisError};
use ff::PrimeField;

/// Similar to `num::Num`, we use `Elt` to accumulate both values and linear combinations, then eventually
/// extract into a `num::AllocatedNum`, enforcing that the linear combination corresponds to the result.
Expand All @@ -22,10 +20,12 @@ enum Elt<Scalar: PrimeField> {
}

impl<Scalar: PrimeField> Elt<Scalar> {
#[allow(unused)]
fn is_allocated(&self) -> bool {
matches!(self, Self::Allocated(_))
}

#[allow(unused)]
fn is_num(&self) -> bool {
matches!(self, Self::Num(_))
}
Expand Down Expand Up @@ -67,10 +67,13 @@ impl<Scalar: PrimeField> Elt<Scalar> {
}
}

fn lc(&self) -> LinearCombination<Scalar> {
#[cfg(test)]
fn lc(&self) -> bellperson::LinearCombination<Scalar> {
match self {
Self::Num(num) => num.lc(Scalar::one()),
Self::Allocated(v) => LinearCombination::<Scalar>::zero() + v.get_variable(),
Self::Allocated(v) => {
bellperson::LinearCombination::<Scalar>::zero() + v.get_variable()
}
}
}

Expand Down Expand Up @@ -99,38 +102,33 @@ impl<Scalar: PrimeField> Elt<Scalar> {
}

/// Circuit for Poseidon hash.
pub struct PoseidonCircuit<'a, Scalar, A>
pub struct PoseidonCircuit<'a, Scalar, const ARITY: usize, const WIDTH: usize>
where
Scalar: PrimeField,
A: Arity<Scalar>,
{
constants_offset: usize,
width: usize,
elements: Vec<Elt<Scalar>>,
pos: usize,
current_round: usize,
constants: &'a PoseidonConstants<Scalar, A>,
_w: PhantomData<A>,
constants: &'a PoseidonConstants<Scalar, ARITY, WIDTH>,
}

/// PoseidonCircuit implementation.
impl<'a, Scalar, A> PoseidonCircuit<'a, Scalar, A>
impl<'a, Scalar, const ARITY: usize, const WIDTH: usize> PoseidonCircuit<'a, Scalar, ARITY, WIDTH>
where
Scalar: PrimeField,
A: Arity<Scalar>,
{
/// Create a new Poseidon hasher for `preimage`.
fn new(elements: Vec<Elt<Scalar>>, constants: &'a PoseidonConstants<Scalar, A>) -> Self {
let width = constants.width();

fn new(
elements: Vec<Elt<Scalar>>,
constants: &'a PoseidonConstants<Scalar, ARITY, WIDTH>,
) -> Self {
PoseidonCircuit {
constants_offset: 0,
width,
width: WIDTH,
elements,
pos: width,
current_round: 0,
constants,
_w: PhantomData::<A>,
}
}

Expand Down Expand Up @@ -333,36 +331,34 @@ where
Ok(())
}

#[allow(unused)]
fn debug(&self) {
let element_frs: Vec<_> = self.elements.iter().map(|n| n.val()).collect::<Vec<_>>();
dbg!(element_frs, self.constants_offset);
}
}

/// Create circuit for Poseidon hash.
pub fn poseidon_hash<CS, Scalar, A>(
pub fn poseidon_hash<CS, F, const ARITY: usize, const WIDTH: usize>(
mut cs: CS,
preimage: Vec<AllocatedNum<Scalar>>,
constants: &PoseidonConstants<Scalar, A>,
) -> Result<AllocatedNum<Scalar>, SynthesisError>
preimage: Vec<AllocatedNum<F>>,
constants: &PoseidonConstants<F, ARITY, WIDTH>,
) -> Result<AllocatedNum<F>, SynthesisError>
where
CS: ConstraintSystem<Scalar>,
Scalar: PrimeField,
A: Arity<Scalar>,
CS: ConstraintSystem<F>,
F: PrimeField,
{
let arity = A::to_usize();
let tag_element = Elt::num_from_fr::<CS>(constants.domain_tag);
let mut elements = Vec::with_capacity(arity + 1);
let mut elements = Vec::with_capacity(WIDTH);
elements.push(tag_element);
elements.extend(preimage.into_iter().map(Elt::Allocated));

if let HashType::ConstantLength(length) = constants.hash_type {
assert!(length <= arity, "illegal length: constants are malformed");
assert!(length <= ARITY, "illegal length: constants are malformed");
// Add zero-padding.
for i in 0..(arity - length) {
let allocated = AllocatedNum::alloc(cs.namespace(|| format!("padding {}", i)), || {
Ok(Scalar::zero())
})?;
for i in 0..(ARITY - length) {
let allocated =
AllocatedNum::alloc(cs.namespace(|| format!("padding {}", i)), || Ok(F::zero()))?;
let elt = Elt::Allocated(allocated);
elements.push(elt);
}
Expand Down Expand Up @@ -564,13 +560,14 @@ where
Ok(res)
}

fn scalar_product_with_add<Scalar: PrimeField, CS: ConstraintSystem<Scalar>>(
elts: &[Elt<Scalar>],
scalars: &[Scalar],
to_add: Scalar,
) -> Result<Elt<Scalar>, SynthesisError> {
let tmp = scalar_product::<Scalar, CS>(elts, scalars)?;
let tmp2 = tmp.add::<CS>(Elt::<Scalar>::num_from_fr::<CS>(to_add))?;
#[cfg(test)]
fn scalar_product_with_add<F: PrimeField, CS: ConstraintSystem<F>>(
elts: &[Elt<F>],
scalars: &[F],
to_add: F,
) -> Result<Elt<F>, SynthesisError> {
let tmp = scalar_product::<F, CS>(elts, scalars)?;
let tmp2 = tmp.add::<CS>(Elt::num_from_fr::<CS>(to_add))?;

Ok(tmp2)
}
Expand All @@ -594,51 +591,48 @@ mod tests {
use bellperson::util_cs::test_cs::TestConstraintSystem;
use bellperson::ConstraintSystem;
use blstrs::Scalar as Fr;
use generic_array::typenum;
use ff::Field;
use rand::SeedableRng;
use rand_xorshift::XorShiftRng;

#[test]
fn test_poseidon_hash() {
test_poseidon_hash_aux::<typenum::U2>(Strength::Standard, 311, false);
test_poseidon_hash_aux::<typenum::U4>(Strength::Standard, 377, false);
test_poseidon_hash_aux::<typenum::U8>(Strength::Standard, 505, false);
test_poseidon_hash_aux::<typenum::U16>(Strength::Standard, 761, false);
test_poseidon_hash_aux::<typenum::U24>(Strength::Standard, 1009, false);
test_poseidon_hash_aux::<typenum::U36>(Strength::Standard, 1385, false);

test_poseidon_hash_aux::<typenum::U2>(Strength::Strengthened, 367, false);
test_poseidon_hash_aux::<typenum::U4>(Strength::Strengthened, 433, false);
test_poseidon_hash_aux::<typenum::U8>(Strength::Strengthened, 565, false);
test_poseidon_hash_aux::<typenum::U16>(Strength::Strengthened, 821, false);
test_poseidon_hash_aux::<typenum::U24>(Strength::Strengthened, 1069, false);
test_poseidon_hash_aux::<typenum::U36>(Strength::Strengthened, 1445, false);

test_poseidon_hash_aux::<typenum::U15>(Strength::Standard, 730, true);
test_poseidon_hash_aux::<2, 3>(Strength::Standard, 311, false);
test_poseidon_hash_aux::<4, 5>(Strength::Standard, 377, false);
test_poseidon_hash_aux::<8, 9>(Strength::Standard, 505, false);
test_poseidon_hash_aux::<16, 17>(Strength::Standard, 761, false);
test_poseidon_hash_aux::<24, 25>(Strength::Standard, 1009, false);
test_poseidon_hash_aux::<36, 37>(Strength::Standard, 1385, false);

test_poseidon_hash_aux::<2, 3>(Strength::Strengthened, 367, false);
test_poseidon_hash_aux::<4, 5>(Strength::Strengthened, 433, false);
test_poseidon_hash_aux::<8, 9>(Strength::Strengthened, 565, false);
test_poseidon_hash_aux::<16, 17>(Strength::Strengthened, 821, false);
test_poseidon_hash_aux::<24, 25>(Strength::Strengthened, 1069, false);
test_poseidon_hash_aux::<36, 37>(Strength::Strengthened, 1445, false);

test_poseidon_hash_aux::<15, 16>(Strength::Standard, 730, true);
}

fn test_poseidon_hash_aux<A>(
fn test_poseidon_hash_aux<const ARITY: usize, const WIDTH: usize>(
strength: Strength,
expected_constraints: usize,
constant_length: bool,
) where
A: Arity<Fr>,
{
) {
let mut rng = XorShiftRng::from_seed(crate::TEST_SEED);
let arity = A::to_usize();
let constants_x = if constant_length {
PoseidonConstants::<Fr, A>::new_with_strength_and_type(
PoseidonConstants::<Fr, ARITY, WIDTH>::new_with_strength_and_type(
strength,
HashType::ConstantLength(arity),
HashType::ConstantLength(ARITY),
)
} else {
PoseidonConstants::<Fr, A>::new_with_strength(strength)
PoseidonConstants::<Fr, ARITY, WIDTH>::new_with_strength(strength)
};

let range = if constant_length {
1..=arity
1..=ARITY
} else {
arity..=arity
ARITY..=ARITY
};
for preimage_length in range {
let mut cs = TestConstraintSystem::<Fr>::new();
Expand All @@ -650,12 +644,11 @@ mod tests {
};
let expected_constraints_calculated = {
let arity_tag_constraints = 0;
let width = 1 + arity;
// The '- 1' term represents the first s-box for the arity tag, which is a constant and needs no constraint.
let s_boxes = (width * constants.full_rounds) + constants.partial_rounds - 1;
let s_boxes = (WIDTH * constants.full_rounds) + constants.partial_rounds - 1;
let s_box_constraints = 3 * s_boxes;
let mds_constraints =
(width * constants.full_rounds) + constants.partial_rounds - arity;
(WIDTH * constants.full_rounds) + constants.partial_rounds - ARITY;
arity_tag_constraints + s_box_constraints + mds_constraints
};
let mut i = 0;
Expand All @@ -673,7 +666,7 @@ mod tests {

let out = poseidon_hash(&mut cs, data, &constants).expect("poseidon hashing failed");

let mut p = Poseidon::<Fr, A>::new_with_preimage(&fr_data, &constants);
let mut p = Poseidon::<Fr, ARITY, WIDTH>::new_with_preimage(&fr_data, &constants);
let expected: Fr = p.hash_in_mode(HashMode::Correct);

assert!(cs.is_satisfied(), "constraints not satisfied");
Expand Down
Loading