diff --git a/recursion/compiler/src/ir/bits.rs b/recursion/compiler/src/ir/bits.rs index 7df2f0e5e5..f69c8cee1d 100644 --- a/recursion/compiler/src/ir/bits.rs +++ b/recursion/compiler/src/ir/bits.rs @@ -1,11 +1,14 @@ -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use sp1_recursion_core::runtime::NUM_BITS; use super::{Array, Builder, Config, DslIr, Felt, Usize, Var}; impl Builder { - /// Converts a variable to bits. + /// Converts a variable to LE bits. pub fn num2bits_v(&mut self, num: Var) -> Array> { + // This function is only used when the native field is Babybear. + assert!(C::N::bits() == NUM_BITS); + let output = self.dyn_array::>(NUM_BITS); self.push(DslIr::HintBitsV(output.clone(), num)); @@ -16,10 +19,10 @@ impl Builder { self.assign(sum, sum + bit * C::N::from_canonical_u32(1 << i)); } - // TODO: There is an edge case where the witnessed bits may slightly overflow and cause - // the output to be incorrect. This is a known issue and will be fixed in the future. self.assert_var_eq(sum, num); + self.less_than_bb_modulus(output.clone()); + output } @@ -49,22 +52,25 @@ impl Builder { }); } - // TODO: There is an edge case where the witnessed bits may slightly overflow and cause - // the output to be incorrect. This is a known issue and will be fixed in the future. self.assert_felt_eq(sum, num); + self.less_than_bb_modulus(output.clone()); + output } /// Converts a felt to bits inside a circuit. pub fn num2bits_f_circuit(&mut self, num: Felt) -> Vec> { let mut output = Vec::new(); - for _ in 0..32 { + for _ in 0..NUM_BITS { output.push(self.uninit()); } self.push(DslIr::CircuitNum2BitsF(num, output.clone())); + let output_array = self.vec(output.clone()); + self.less_than_bb_modulus(output_array); + output } @@ -149,4 +155,33 @@ impl Builder { } result_bits } + + /// Checks that the LE bit decomposition of a number is less than the babybear modulus. + /// + /// SAFETY: This function assumes that the num_bits values are already verified to be boolean. + /// + /// The babybear modulus in LE bits is: 100_000_000_000_000_000_000_000_000_111_1. + /// To check that the num_bits array is less than that value, we first check if the most significant + /// bits are all 1. If it is, then we assert that the other bits are all 0. + fn less_than_bb_modulus(&mut self, num_bits: Array>) { + let one: Var<_> = self.eval(C::N::one()); + let zero: Var<_> = self.eval(C::N::zero()); + + let mut most_sig_4_bits = one; + for i in (NUM_BITS - 4)..NUM_BITS { + let bit = self.get(&num_bits, i); + most_sig_4_bits = self.eval(bit * most_sig_4_bits); + } + + let mut sum_least_sig_bits = zero; + for i in 0..(NUM_BITS - 4) { + let bit = self.get(&num_bits, i); + sum_least_sig_bits = self.eval(bit + sum_least_sig_bits); + } + + // If the most significant 4 bits are all 1, then check the sum of the least significant bits, else return zero. + let check: Var<_> = + self.eval(most_sig_4_bits * sum_least_sig_bits + (one - most_sig_4_bits) * zero); + self.assert_var_eq(check, zero); + } } diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 77a118326a..a37ce6858c 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -697,7 +697,7 @@ where // Get the src value. let num = b_val[0].as_canonical_u32(); - // Decompose the num into bits. + // Decompose the num into LE bits. let bits = (0..NUM_BITS).map(|i| (num >> i) & 1).collect::>(); // Write the bits to the array at dst. for (i, bit) in bits.iter().enumerate() {