Skip to content

Commit

Permalink
implement LogUp approach and padding
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Jan 22, 2024
1 parent 29afc38 commit 3d2313d
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 1,023 deletions.
101 changes: 71 additions & 30 deletions src/gadgets/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ impl<E: Engine> LookupTrace<E> {
let read_value_term = gamma.mul(cs.namespace(|| "read_value_term"), read_value)?;
// counter_term = gamma^2 * counter
let read_counter_term = gamma_square.mul(cs.namespace(|| "read_counter_term"), read_counter)?;
// new_R = R * (alpha - (addr + gamma * value + gamma^2 * counter))
// new_R = R + 1 / (alpha + (addr + gamma * value + gamma^2 * counter))
let new_R = AllocatedNum::alloc(cs.namespace(|| "new_R"), || {
prev_R
.get_value()
Expand All @@ -291,20 +291,23 @@ impl<E: Engine> LookupTrace<E> {
.zip(read_value_term.get_value())
.zip(read_counter_term.get_value())
.map(|((((R, alpha), addr), value_term), counter_term)| {
R * (alpha - (addr + value_term + counter_term))
R + (alpha + (addr + value_term + counter_term))
.invert()
.expect("invert failed due to read term is 0") // negilible probability for invert failed
})
.ok_or(SynthesisError::AssignmentMissing)
})?;
let mut r_blc = LinearCombination::<F>::zero();
r_blc = r_blc + alpha.get_variable()
- addr.get_variable()
- read_value_term.get_variable()
- read_counter_term.get_variable();
r_blc = r_blc
+ alpha.get_variable()
+ addr.get_variable()
+ read_value_term.get_variable()
+ read_counter_term.get_variable();
cs.enforce(
|| "R update",
|lc| lc + prev_R.get_variable(),
|lc| lc + new_R.get_variable() - prev_R.get_variable(),
|_| r_blc,
|lc| lc + new_R.get_variable(),
|lc| lc + CS::one(),
);

let alloc_num_one = alloc_one(cs.namespace(|| "one"));
Expand Down Expand Up @@ -350,29 +353,35 @@ impl<E: Engine> LookupTrace<E> {
.zip(gamma_square.get_value())
.map(
|(((((W, alpha), addr), value_term), write_counter_term), gamma_square)| {
W * (alpha - (addr + value_term + write_counter_term + gamma_square))
W + (alpha + (addr + value_term + write_counter_term + gamma_square))
.invert()
.expect("invert failed due to write term is 0") // negilible probability for invert failed
},
)
.ok_or(SynthesisError::AssignmentMissing)
})?;
// new_W = W * (alpha - (addr + gamma * value + gamma^2 * counter + gamma^2)))
// new_W = W + 1 / (alpha - (addr + gamma * value + gamma^2 * counter))
let mut w_blc = LinearCombination::<F>::zero();
w_blc = w_blc + alpha.get_variable()
- addr.get_variable()
- write_value_term.get_variable()
- write_counter_term.get_variable()
- gamma_square.get_variable();
w_blc = w_blc
+ alpha.get_variable()
+ addr.get_variable()
+ write_value_term.get_variable()
+ write_counter_term.get_variable()
+ gamma_square.get_variable();
cs.enforce(
|| "W update",
|lc| lc + prev_W.get_variable(),
|lc| lc + new_W.get_variable() - prev_W.get_variable(),
|_| w_blc,
|lc| lc + new_W.get_variable(),
|lc| lc + CS::one(),
);

let new_rw_counter = add_allocated_num(
cs.namespace(|| "new_rw_counter"),
&write_counter,
&alloc_num_one,
)?;

// update accu
Ok((new_R, new_W, new_rw_counter))
}
}
Expand Down Expand Up @@ -558,7 +567,11 @@ impl<F: PrimeField> Lookup<F> {
Self {
map_aux: initial_table
.into_iter()
.map(|(addr, value)| (addr, (value, F::ZERO)))
.enumerate()
.map(|(i, (addr, value))| {
assert!(F::from(i as u64) == addr);
(addr, (value, F::ZERO))
})
.collect(),
rw_counter: F::ZERO,
table_type,
Expand All @@ -571,6 +584,19 @@ impl<F: PrimeField> Lookup<F> {
self.map_aux.len()
}

/// padding
pub fn padding(&mut self, N: usize)
where
F: Ord,
{
assert!(self.map_aux.len() <= N);
(self.map_aux.len()..N).for_each(|addr| {
self
.map_aux
.insert(F::from(addr as u64), (F::ZERO, F::ZERO));
});
}

/// table values
pub fn values(&self) -> Values<'_, F, (F, F)> {
self.map_aux.values()
Expand Down Expand Up @@ -829,8 +855,12 @@ mod test {
.zip(addr.get_value())
.zip(read_value.get_value())
.map(|((((prev_R, alpha), gamma), addr), read_value)| prev_R
* (alpha - (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ZERO))
* (alpha - (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ONE)))
+ (alpha + (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ZERO))
.invert()
.unwrap()
+ (alpha + (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ONE))
.invert()
.unwrap())
);
// next_W check
assert_eq!(
Expand All @@ -843,9 +873,13 @@ mod test {
.zip(read_value.get_value())
.map(|((((prev_W, alpha), gamma), addr), read_value)| {
prev_W
* (alpha - (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::ONE)))
* (alpha
- (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::from(2))))
+ (alpha + (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::ONE)))
.invert()
.unwrap()
+ (alpha
+ (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::from(2))))
.invert()
.unwrap()
}),
);

Expand Down Expand Up @@ -905,7 +939,6 @@ mod test {
write_value_1.get_value().unwrap(),
);
let read_value = lookup_trace_builder.read(addr.get_value().unwrap());
// cs.namespace(|| "read_value 1"),
assert_eq!(read_value, <E1 as Engine>::Scalar::from(101));
let (_, mut lookup_trace) = lookup_trace_builder.snapshot::<E2>(
ro_consts.clone(),
Expand Down Expand Up @@ -948,11 +981,15 @@ mod test {
.zip(addr.get_value())
.zip(read_value.get_value())
.map(|((((prev_R, alpha), gamma), addr), read_value)| prev_R
* (alpha
- (addr
+ (alpha
+ (addr
+ gamma * <E1 as Engine>::Scalar::ZERO
+ gamma * gamma * <E1 as Engine>::Scalar::ZERO))
* (alpha - (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ONE)))
.invert()
.unwrap()
+ (alpha + (addr + gamma * read_value + gamma * gamma * <E1 as Engine>::Scalar::ONE))
.invert()
.unwrap())
);
// next_W check
assert_eq!(
Expand All @@ -965,9 +1002,13 @@ mod test {
.zip(read_value.get_value())
.map(|((((prev_W, alpha), gamma), addr), read_value)| {
prev_W
* (alpha - (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::ONE)))
* (alpha
- (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::from(2))))
+ (alpha + (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::ONE)))
.invert()
.unwrap()
+ (alpha
+ (addr + gamma * read_value + gamma * gamma * (<E1 as Engine>::Scalar::from(2))))
.invert()
.unwrap()
}),
);

Expand Down
39 changes: 23 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ where
/// let circuit2 = TrivialCircuit::<<E2 as Engine>::Scalar>::default();
/// // Only relevant for a SNARK using computation commitmnets, pass &(|_| 0)
/// // or &*nova_snark::traits::snark::default_ck_hint() otherwise.
/// let ck_hint1 = &*SPrime::<E1>::ck_floor();
/// let ck_hint2 = &*SPrime::<E2>::ck_floor();
/// let ck_hint1 = &*SPrime::<E1, EE<_>>::ck_floor();
/// let ck_hint2 = &*SPrime::<E2, EE<_>>::ck_floor();
///
/// let pp = PublicParams::setup(&circuit1, &circuit2, ck_hint1, ck_hint2);
/// ```
Expand Down Expand Up @@ -1055,7 +1055,10 @@ where
VerifierKeyV2<E1, E2, C1, C2, S1, S2>,
),
NovaError,
> {
>
where
<E1 as Engine>::Scalar: Ord,
{
let (pk_primary, vk_primary) = S1::setup(
&pp.ck_primary,
&pp.circuit_shape_primary.r1cs_shape,
Expand Down Expand Up @@ -1119,8 +1122,8 @@ where
challenges,
read_row,
write_row,
initial_table,
final_table,
initial_table.clone(),
final_table.clone(),
)
},
|| {
Expand Down Expand Up @@ -2340,15 +2343,23 @@ mod tests {

let circuit_secondary = TrivialCircuit::default();

let ck_hint1 = &*SPrime::<E1, EE<_>>::ck_floor();
let ck_hint2 = &*SPrime::<E2, EE<_>>::ck_floor();

// produce public parameters
let pp =
PublicParams::<E1, E2, HeapifyCircuit<E1, E2>, TrivialCircuit<<E2 as Engine>::Scalar>>::setup(
&circuit_primaries[0],
&circuit_secondary,
&*default_ck_hint(),
&*default_ck_hint(),
ck_hint1,
ck_hint2,
);

// produce the prover and verifier keys for compressed snark
let (pk, vk) =
CompressedSNARKV2::<_, _, _, _, S1<E1, EE<E1>>, S2<E2, EE<E2>>>::setup(&pp, &initial_table)
.unwrap();

let z0_primary =
HeapifyCircuit::<E1, E2>::get_z0(&pp.ck_primary, &final_table, expected_intermediate_gamma);

Expand Down Expand Up @@ -2399,10 +2410,10 @@ mod tests {
zn_primary[5]
); // rw counter = number_of_iterated_nodes * (3r + 4w) operations

assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_cons, 12599);
assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_vars, 12607);
assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_cons, 10347);
assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_vars, 10329);
assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_cons, 12609);
assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_vars, 12615);
assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_cons, 10357);
assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_vars, 10337);

println!("zn_primary {:?}", zn_primary);

Expand All @@ -2416,11 +2427,6 @@ mod tests {
"expected_intermediate_gamma != intermediate_gamma"
);

// produce the prover and verifier keys for compressed snark
let (pk, vk) =
CompressedSNARKV2::<_, _, _, _, S1<E1, EE<E1>>, S2<E2, EE<E2>>>::setup(&pp, &initial_table)
.unwrap();

// produce a compressed SNARK
let res = CompressedSNARKV2::<_, _, _, _, S1<E1, EE<E1>>, S2<E2, EE<E2>>>::prove(
&pp,
Expand All @@ -2446,6 +2452,7 @@ mod tests {
write_row,
(alpha, gamma),
);
println!("res: {:?}", res);
assert!(res.is_ok());
}
}
Loading

0 comments on commit 3d2313d

Please sign in to comment.