Skip to content

Commit

Permalink
fix(recursion): num2bits fixes (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue authored May 22, 2024
1 parent 4718a9e commit 5b00b1d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
49 changes: 42 additions & 7 deletions recursion/compiler/src/ir/bits.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use p3_field::AbstractField;
use p3_field::{AbstractField, Field};
use sp1_recursion_core::runtime::NUM_BITS;

use super::{Array, Builder, Config, DslIr, Felt, Usize, Var};

impl<C: Config> Builder<C> {
/// Converts a variable to bits.
/// Converts a variable to LE bits.
pub fn num2bits_v(&mut self, num: Var<C::N>) -> Array<C, Var<C::N>> {
// This function is only used when the native field is Babybear.
assert!(C::N::bits() == NUM_BITS);

let output = self.dyn_array::<Var<_>>(NUM_BITS);
self.push(DslIr::HintBitsV(output.clone(), num));

Expand All @@ -16,10 +19,10 @@ impl<C: Config> Builder<C> {
self.assign(sum, sum + bit * C::N::from_canonical_u32(1 << i));
}

// TODO: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect. This is a known issue and will be fixed in the future.
self.assert_var_eq(sum, num);

self.less_than_bb_modulus(output.clone());

output
}

Expand Down Expand Up @@ -49,22 +52,25 @@ impl<C: Config> Builder<C> {
});
}

// TODO: There is an edge case where the witnessed bits may slightly overflow and cause
// the output to be incorrect. This is a known issue and will be fixed in the future.
self.assert_felt_eq(sum, num);

self.less_than_bb_modulus(output.clone());

output
}

/// Converts a felt to bits inside a circuit.
pub fn num2bits_f_circuit(&mut self, num: Felt<C::F>) -> Vec<Var<C::N>> {
let mut output = Vec::new();
for _ in 0..32 {
for _ in 0..NUM_BITS {
output.push(self.uninit());
}

self.push(DslIr::CircuitNum2BitsF(num, output.clone()));

let output_array = self.vec(output.clone());
self.less_than_bb_modulus(output_array);

output
}

Expand Down Expand Up @@ -149,4 +155,33 @@ impl<C: Config> Builder<C> {
}
result_bits
}

/// Checks that the LE bit decomposition of a number is less than the babybear modulus.
///
/// SAFETY: This function assumes that the num_bits values are already verified to be boolean.
///
/// The babybear modulus in LE bits is: 100_000_000_000_000_000_000_000_000_111_1.
/// To check that the num_bits array is less than that value, we first check if the most significant
/// bits are all 1. If it is, then we assert that the other bits are all 0.
fn less_than_bb_modulus(&mut self, num_bits: Array<C, Var<C::N>>) {
let one: Var<_> = self.eval(C::N::one());
let zero: Var<_> = self.eval(C::N::zero());

let mut most_sig_4_bits = one;
for i in (NUM_BITS - 4)..NUM_BITS {
let bit = self.get(&num_bits, i);
most_sig_4_bits = self.eval(bit * most_sig_4_bits);
}

let mut sum_least_sig_bits = zero;
for i in 0..(NUM_BITS - 4) {
let bit = self.get(&num_bits, i);
sum_least_sig_bits = self.eval(bit + sum_least_sig_bits);
}

// If the most significant 4 bits are all 1, then check the sum of the least significant bits, else return zero.
let check: Var<_> =
self.eval(most_sig_4_bits * sum_least_sig_bits + (one - most_sig_4_bits) * zero);
self.assert_var_eq(check, zero);
}
}
2 changes: 1 addition & 1 deletion recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ where
// Get the src value.
let num = b_val[0].as_canonical_u32();

// Decompose the num into bits.
// Decompose the num into LE bits.
let bits = (0..NUM_BITS).map(|i| (num >> i) & 1).collect::<Vec<_>>();
// Write the bits to the array at dst.
for (i, bit) in bits.iter().enumerate() {
Expand Down

0 comments on commit 5b00b1d

Please sign in to comment.