Skip to content

Commit

Permalink
r1cs: remove optimized matmul
Browse files Browse the repository at this point in the history
From Section 6.2.1 of the Poseidon paper:

"we do not need more constraints [than that from the SBoxes] as
the linear layers and round constants can be incorporated into
these ones"

This means that the optimized matmul is only useful out of circuit,
where we currently use it.
  • Loading branch information
redshiftzero committed Jun 8, 2023
1 parent 3280081 commit dd6d38b
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 209 deletions.
188 changes: 24 additions & 164 deletions poseidon-permutation/src/r1cs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(non_snake_case)]
use ark_ff::PrimeField;
use ark_std::{vec, vec::Vec};
use ark_std::{dbg, vec::Vec};

use ark_r1cs_std::{fields::fp::FpVar, prelude::*};
use ark_relations::r1cs::ConstraintSystemRef;
Expand All @@ -10,8 +10,6 @@ use poseidon_parameters::v1::{Alpha, MatrixOperations, PoseidonParameters};
pub struct InstanceVar<F: PrimeField> {
/// Parameters for this instance of Poseidon.
pub parameters: PoseidonParameters<F>,
/// Constant FpVar representing the first element of the MDS matrix (used in optimized version).
M00_var: FpVar<F>,

/// Constraint system
pub cs: ConstraintSystemRef<F>,
Expand All @@ -24,63 +22,34 @@ impl<F> InstanceVar<F>
where
F: PrimeField,
{
/// Initialize a new Poseidon instance.
pub fn new(parameters: PoseidonParameters<F>, cs: ConstraintSystemRef<F>) -> Self {
let zero = FpVar::<F>::zero();
// t = rate + capacity
let state_words = vec![zero; parameters.t];
let M00_var = FpVar::<F>::new_constant(cs.clone(), parameters.optimized_mds.M_00)
.expect("can create constant");

Self {
parameters,
cs,
state_words,
M00_var,
}
}

/// Fixed width hash from n:1. Outputs a F given `t` input words.
pub fn n_to_1_fixed_hash(&mut self, input_words: Vec<FpVar<F>>) -> FpVar<F> {
pub fn n_to_1_fixed_hash(
parameters: PoseidonParameters<F>,
cs: ConstraintSystemRef<F>,
input_words: Vec<FpVar<F>>,
) -> FpVar<F> {
// Check input words are `t` elements long
if input_words.len() != self.parameters.t {
if input_words.len() != parameters.t {
panic!("err: input words must be t elements long")
}

// Set internal state words.
for (i, input_word) in input_words.into_iter().enumerate() {
self.state_words[i] = input_word
}

// Apply Poseidon permutation.
//self.unoptimized_permute();
self.permute();

// Emit a single element since this is a n:1 hash.
self.state_words[1].clone()
}

/// Fixed width hash from n:1. Outputs a F given `t` input words.
pub fn unoptimized_n_to_1_fixed_hash(&mut self, input_words: Vec<FpVar<F>>) -> FpVar<F> {
// Check input words are `t` elements long
if input_words.len() != self.parameters.t {
panic!("err: input words must be t elements long")
}
// t = rate + capacity

// Set internal state words.
for (i, input_word) in input_words.into_iter().enumerate() {
self.state_words[i] = input_word
}
let mut instance = InstanceVar {
parameters,
cs,
state_words: input_words,
};

// Apply Poseidon permutation.
self.unoptimized_permute();
instance.permute();

// Emit a single element since this is a n:1 hash.
self.state_words[1].clone()
instance.state_words[1].clone()
}

/// Unoptimized Poseidon permutation.
pub fn unoptimized_permute(&mut self) {
/// Poseidon permutation.
pub fn permute(&mut self) {
let R_f = self.parameters.rounds.full() / 2;
let R_P = self.parameters.rounds.partial();
let mut round_constants_counter = 0;
Expand All @@ -94,7 +63,9 @@ where
self.state_words[i] += round_constants[round_constants_counter];
round_constants_counter += 1;
}
dbg!("bout to SubWord: ", self.cs.num_constraints());
self.full_sub_words();
dbg!("after x^a: ", self.cs.num_constraints());
self.mix_layer_mds();
}

Expand All @@ -121,60 +92,6 @@ where
}
}

/// Permutes the internal state.
///
/// This implementation is based on the optimized Sage implementation
/// `poseidonperm_x3_64_optimized.sage` provided in Appendix B of the Poseidon paper.
fn permute(&mut self) {
let R_f = self.parameters.rounds.full() / 2;
let R_P = self.parameters.rounds.partial();
let round_constants = self.parameters.optimized_arc.0.clone();
let t = self.parameters.t;

// First chunk of full rounds
for r in 0..R_f {
// Apply `AddRoundConstants` layer
for i in 0..t {
self.state_words[i] += round_constants.get_element(r, i);
}
self.full_sub_words();
self.mix_layer_mds();
}
let mut round_constants_counter = R_f;

// Partial rounds
// First part of `AddRoundConstants` layer
for i in 0..t {
self.state_words[i] += round_constants.get_element(round_constants_counter, i);
}
// First full matrix multiplication.
self.mix_layer_mi();

for r in 0..R_P - 1 {
self.partial_sub_words();
// Rest of `AddRoundConstants` layer, moved to after the S-box layer
round_constants_counter += 1;
self.state_words[0] += round_constants.get_element(round_constants_counter, 0);
self.sparse_mat_mul(R_P - r - 1);
}

// Last partial round
self.partial_sub_words();
self.sparse_mat_mul(0);
round_constants_counter += 1;

// Final full rounds
for _ in 0..R_f {
// Apply `AddRoundConstants` layer
for i in 0..t {
self.state_words[i] += round_constants.get_element(round_constants_counter, i);
}
self.full_sub_words();
self.mix_layer_mds();
round_constants_counter += 1;
}
}

/// Applies the partial `SubWords` layer.
fn partial_sub_words(&mut self) {
match self.parameters.alpha {
Expand All @@ -191,11 +108,11 @@ where
fn full_sub_words(&mut self) {
match self.parameters.alpha {
Alpha::Exponent(exp) => {
self.state_words = self
.state_words
.iter()
.map(|x| x.pow_by_constant([exp as u64]).expect("can compute pow"))
.collect()
for i in 0..self.parameters.t {
self.state_words[i] = (self.state_words[i])
.pow_by_constant([exp as u64])
.expect("can compute pow");
}
}
Alpha::Inverse => {
unimplemented!("err: inverse alpha not implemented")
Expand Down Expand Up @@ -225,61 +142,4 @@ where
})
.collect();
}

/// Applies the `MixLayer` using the M_i matrix.
fn mix_layer_mi(&mut self) {
self.state_words = self
.parameters
.optimized_mds
.M_i
.iter_rows()
.map(|row| {
let temp_vec: Vec<FpVar<F>> = row
.iter()
.zip(&self.state_words)
.map(|(x, y)| {
FpVar::<F>::new_constant(self.cs.clone(), x).expect("can create constant")
* y
})
.collect();
let result = temp_vec.iter().sum();
result
})
.collect();
}

/// This is `cheap_matrix_mul` in the Sage spec
fn sparse_mat_mul(&mut self, round_number: usize) {
// mul_row = [(state_words[0] * v[i]) for i in range(0, t-1)]
// add_row = [(mul_row[i] + state_words[i+1]) for i in range(0, t-1)]
let add_row: Vec<FpVar<F>> = self.parameters.optimized_mds.v_collection[round_number]
.elements
.iter()
.enumerate()
.map(|(i, x)| {
FpVar::<F>::new_constant(self.cs.clone(), x).expect("can create constant")
* &self.state_words[0]
+ &self.state_words[i + 1]
})
.collect();

// column_1 = [M_0_0] + w_hat
// state_words_new[0] = sum([column_1[i] * state_words[i] for i in range(0, t)])
// state_words_new = [state_words_new[0]] + add_row
let temp_vec: Vec<FpVar<F>> = self.parameters.optimized_mds.w_hat_collection[round_number]
.elements
.iter()
.zip(self.state_words[1..self.parameters.t].iter())
.map(|(x, y)| {
FpVar::<F>::new_constant(self.cs.clone(), x).expect("can create constant") * y
})
.collect();
self.state_words[0] =
&self.M00_var * &self.state_words[0] + temp_vec.iter().sum::<FpVar<F>>();

// self.state_words[1..self.parameters.t].copy_from_slice(&add_row[..(self.parameters.t - 1)]);
for index in 1..self.parameters.t {
self.state_words[index] = add_row[index - 1].clone();
}
}
}
103 changes: 59 additions & 44 deletions poseidon377/src/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,90 +9,105 @@ pub fn hash_1(
domain_separator: &FqVar,
value: FqVar,
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_1_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![domain_separator.clone(), value]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_1_PARAMS.clone(),
cs.clone(),
vec![domain_separator.clone(), value],
))
}

pub fn hash_2(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_2_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![domain_separator.clone(), value.0, value.1]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_2_PARAMS.clone(),
cs.clone(),
vec![domain_separator.clone(), value.0, value.1],
))
}

pub fn hash_3(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_3_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![domain_separator.clone(), value.0, value.1, value.2]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_3_PARAMS.clone(),
cs.clone(),
vec![domain_separator.clone(), value.0, value.1, value.2],
))
}

pub fn hash_4(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar, FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_4_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_4_PARAMS.clone(),
cs.clone(),
vec![domain_separator.clone(), value.0, value.1, value.2, value.3],
))
}

pub fn hash_5(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar, FqVar, FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_5_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_5_PARAMS.clone(),
cs.clone(),
vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
],
))
}

pub fn hash_6(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar, FqVar, FqVar, FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_6_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
value.5,
]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_6_PARAMS.clone(),
cs.clone(),
vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
value.5,
],
))
}

pub fn hash_7(
cs: ConstraintSystemRef<Fq>,
domain_separator: &FqVar,
value: (FqVar, FqVar, FqVar, FqVar, FqVar, FqVar, FqVar),
) -> Result<FqVar, SynthesisError> {
let mut state = InstanceVar::new(crate::RATE_7_PARAMS.clone(), cs.clone());
Ok(state.n_to_1_fixed_hash(vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
value.5,
value.6,
]))
Ok(InstanceVar::n_to_1_fixed_hash(
crate::RATE_7_PARAMS.clone(),
cs.clone(),
vec![
domain_separator.clone(),
value.0,
value.1,
value.2,
value.3,
value.4,
value.5,
value.6,
],
))
}
Loading

0 comments on commit dd6d38b

Please sign in to comment.