-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wip] first pass at framing arithmetic #10
Changes from 1 commit
225efcd
9c9b845
3447322
4fa3faf
29eb0b3
5c14137
ce4b512
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
use num_traits::{Inv, Zero}; | ||
use std::ops::{Add, Div, Mul, Neg, Sub}; | ||
|
||
/// A finite field scalar optimized for use in cryptographic operations. | ||
/// | ||
/// All operations feature modular arithmetic, implemented in constant time. | ||
|
@@ -28,16 +27,18 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> { | |
} | ||
// TODO(Cache these for a given modulus for the lifetime of the program) | ||
// If it can be done in a way which doesn't introduce side-channel attacks | ||
let correction = Self::subtraction_correction(&modulus); | ||
let r_squared = Self::compute_r_squared(&modulus); | ||
let n_prime = Self::compute_n_prime(&modulus); | ||
Self { | ||
let r = Self::compute_r(&modulus); | ||
let mut retval = Self { | ||
modulus, | ||
value, | ||
correction, | ||
r_squared, | ||
n_prime, | ||
} | ||
correction: [0u64; L], | ||
r_squared: [0u64; L], | ||
n_prime: 0u64, | ||
}; | ||
retval.correction = Self::subtraction_correction(&modulus); | ||
retval.n_prime = Self::compute_n_prime(&modulus); | ||
retval.r_squared = retval.montgomery_multiply(&r, &r); | ||
retval | ||
} | ||
|
||
/// Computes the correction factor for efficient subtraction. | ||
|
@@ -86,33 +87,40 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> { | |
} | ||
|
||
const fn compute_r(modulus: &[u64; L]) -> [u64; L] { | ||
// TODO (Implement Montgomery r squared) | ||
Self::zero_array() | ||
} | ||
let diff = Self::subtraction_correction(modulus); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be passed in to the function rather than recomputed. |
||
let mut r = [0u64; L]; | ||
|
||
const fn compute_r_squared(modulus: &[u64; L]) -> [u64; L] { | ||
// TODO (Implement Montgomery r squared) | ||
Self::zero_array() | ||
} | ||
|
||
const fn compute_r_cubed(modulus: &[u64; L]) -> [u64; L] { | ||
// TODO (Implement Montgomery r squared) | ||
Self::zero_array() | ||
//R = (2^(64*L) - modulus) + modulus | ||
0xAlcibiades marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let mut i = 0; | ||
let mut carry = 0u64; | ||
while i < L { | ||
let (sum, c1) = diff[i].overflowing_add(modulus[i]); | ||
let (sum, c2 ) = sum.overflowing_add(carry); | ||
r[i] = sum; | ||
carry = (c1 as u64) + (c2 as u64); | ||
i+= 1; | ||
} | ||
r //by definition in range, so no need to montgomery reduce | ||
} | ||
|
||
const fn compute_n_prime(modulus: &[u64; L]) -> u64 { | ||
// TODO (Implement Montgomery n prime) | ||
0u64 | ||
let n = modulus[0]; //need only least significant bits | ||
let mut n_prime = 1u64; | ||
let mut i = 0; | ||
while i < 64 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this number, because 64 bits? |
||
n_prime = n_prime.wrapping_mul(n); | ||
n_prime = n_prime.wrapping_mul(2u64.wrapping_sub(n.wrapping_mul(n_prime))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this number? |
||
i += 1; | ||
} | ||
n_prime.wrapping_neg() | ||
} | ||
|
||
pub const fn to_montgomery(&self, a: &[u64; L]) -> [u64; L] { | ||
// TODO (Implement to monty form) | ||
Self::zero_array() | ||
self.montgomery_multiply(a, &self.r_squared) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some detail on why multiplying by r_squared gives us the monty form would be helpful. |
||
} | ||
|
||
pub const fn from_montgomery(&self, a: &[u64; L]) -> [u64; L] { | ||
// TODO (Implement from monty form) | ||
Self::zero_array() | ||
self.montgomery_multiply(a, &Self::one_array()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here on why multiplying by 1 (the identity, not an array filled with ones) gives us the normal form. |
||
} | ||
|
||
/// Performs Montgomery multiplication of two large integers represented as arrays of u64. | ||
|
@@ -198,110 +206,177 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> { | |
j += 1; | ||
} | ||
|
||
result | ||
self.to_montgomery(&result) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why here do we convert to_montgomery? Should this not already be in monty form at the end of this, or has it been reduced? |
||
} | ||
|
||
pub const fn bernstein_yang_invert(&self, a: &[u64; L]) -> [u64; L] { | ||
// TODO: implement bernstein yang inversion | ||
Self::zero_array() | ||
|
||
pub const fn greater_than(&self, a: &[u64; L], b:&[u64;L])-> bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not truly constant time as different branches on the comparator may return at different times, we should instead always run through the whole loop and then return only afterwards. |
||
let mut i = L; | ||
while i>0 { | ||
i -= 1; | ||
if a[i] > b[i] { | ||
return true; | ||
} | ||
if a[i] < b[i] { | ||
return false; | ||
} | ||
} | ||
false | ||
} | ||
pub const fn is_zero(&self, a: &[u64; L]) -> bool { | ||
let mut retval = true; | ||
let mut i = 0; | ||
while i < L { | ||
retval &= a[i] == 0; | ||
i += 1; | ||
} | ||
retval | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Add for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
/// Performs modular addition. | ||
/// | ||
/// This method adds two field elements and reduces the result modulo the field's modulus. | ||
fn add(self, other: Self) -> Self { | ||
pub const fn add_mod_internal(&self, a: &[u64; L], b: &[u64; L]) -> [u64;L] { | ||
// Initialize sum to zero | ||
let mut sum = Self::new(self.modulus, Self::zero_array()); | ||
let mut sum = Self::zero_array(); | ||
let mut carry = false; | ||
let mut i = 0; | ||
|
||
// Perform addition with carry propagation | ||
for i in 0..L { | ||
let sum_with_other = self.value[i].overflowing_add(other.value[i]); | ||
while i < L { | ||
let sum_with_other = a[i].overflowing_add(b[i]); | ||
let sum_with_carry = sum_with_other.0.overflowing_add(if carry { 1 } else { 0 }); | ||
sum.value[i] = sum_with_carry.0; | ||
sum[i] = sum_with_carry.0; | ||
carry = sum_with_other.1 | sum_with_carry.1; | ||
i += 1; | ||
} | ||
|
||
// Perform trial subtraction of modulus | ||
let mut trial = Self::new(self.modulus, Self::zero_array()); | ||
i = 0; | ||
let mut trial = Self::zero_array(); | ||
let mut borrow = false; | ||
for i in 0..L { | ||
while i < L { | ||
// Note: a single overflowing_sub is enough because modulus[i]+borrow can never overflow | ||
let diff_with_borrow = | ||
sum.value[i].overflowing_sub(self.modulus[i] + if borrow { 1 } else { 0 }); | ||
trial.value[i] = diff_with_borrow.0; | ||
sum[i].overflowing_sub(self.modulus[i] + if borrow { 1 } else { 0 }); | ||
trial[i] = diff_with_borrow.0; | ||
borrow = diff_with_borrow.1; | ||
i += 1; | ||
} | ||
|
||
// Select between sum and trial based on borrow flag | ||
let mut result = Self::new(self.modulus, Self::zero_array()); | ||
i = 0; | ||
let mut result = Self::zero_array(); | ||
let select_mask = u64::from(borrow).wrapping_neg(); | ||
for i in 0..L { | ||
while i < L { | ||
// If borrow is true (select_mask is all 1s), choose sum, otherwise choose trial | ||
result.value[i] = (select_mask & sum.value[i]) | (!select_mask & trial.value[i]); | ||
result[i] = (select_mask & sum[i]) | (!select_mask & trial[i]); | ||
i += 1; | ||
} | ||
result | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Neg for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
fn neg(self) -> Self { | ||
let zero = Self::new(self.modulus, Self::zero_array()); | ||
let z = self == zero; | ||
let mut negated = Self::new(self.modulus, Self::zero_array()); | ||
for i in 0..L { | ||
negated.value[i] = self.modulus[i].wrapping_sub(self.value[i]); | ||
} | ||
if z { | ||
zero | ||
} else { | ||
negated | ||
} | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Sub for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
/// Performs modular subtraction. | ||
/// | ||
/// This method subtracts one field element from another and ensures the result | ||
/// is in the correct range by adding the modulus if necessary. | ||
fn sub(self, other: Self) -> Self { | ||
pub const fn sub_internal(&self, a: &[u64; L], b: &[u64; L]) -> [u64; L] { | ||
// Initialize difference to zero | ||
let mut difference = Self::new(self.modulus, Self::zero_array()); | ||
let mut difference = Self::zero_array(); | ||
let mut borrow = false; | ||
let mut i = 0; | ||
|
||
// Perform subtraction with borrow propagation | ||
for i in 0..L { | ||
let diff_without_borrow = self.value[i].overflowing_sub(other.value[i]); | ||
while i < L { | ||
let diff_without_borrow = a[i].overflowing_sub(b[i]); | ||
let diff_with_borrow = | ||
diff_without_borrow | ||
.0 | ||
.overflowing_sub(if borrow { 1 } else { 0 }); | ||
difference.value[i] = diff_with_borrow.0; | ||
difference[i] = diff_with_borrow.0; | ||
borrow = diff_without_borrow.1 | diff_with_borrow.1; | ||
i += 1; | ||
} | ||
|
||
// Always subtract the correction, which effectively adds the modulus if borrow occurred | ||
let correction_mask = u64::from(borrow).wrapping_neg(); | ||
let mut correction_borrow = false; | ||
for i in 0..L { | ||
i = 0; | ||
while i < L { | ||
let correction_term = | ||
(correction_mask & self.correction[i]) + if correction_borrow { 1 } else { 0 }; | ||
let (corrected_limb, new_borrow) = difference.value[i].overflowing_sub(correction_term); | ||
difference.value[i] = corrected_limb; | ||
let (corrected_limb, new_borrow) = difference[i].overflowing_sub(correction_term); | ||
difference[i] = corrected_limb; | ||
correction_borrow = new_borrow; | ||
i += 1; | ||
} | ||
|
||
difference | ||
} | ||
|
||
pub const fn neg_internal(&self, a: &[u64; L]) -> [u64; L] { | ||
let zero = Self::zero_array(); | ||
let z = self.is_zero(a); | ||
let mut negated = Self::zero_array(); | ||
let mut i = 0; | ||
while i < L { | ||
negated[i] = self.modulus[i].wrapping_sub(a[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm thinking some carry about the modulus needs to happen here. |
||
i += 1; | ||
} | ||
if z { | ||
zero | ||
} else { | ||
negated | ||
} | ||
} | ||
/// The following performs the Bernstein-Yang inversion on a scalar. | ||
pub const fn bernstein_yang_invert(&self, a: &[u64; L]) -> [u64; L] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perfect use case for dimensionals 💯 |
||
let mut u = *a; | ||
let mut v = self.modulus; | ||
let mut r = Self::zero_array(); | ||
let mut s = Self::one_array(); | ||
|
||
let mut i = 0; | ||
while i < 256*L { | ||
if self.is_zero(&v) { | ||
break; | ||
} | ||
if self.is_even(&u){ | ||
u = self.div2(&u); | ||
s = self.mul2(&s); | ||
} else if self.is_even(&v){ | ||
v = self.div2(&v); | ||
r = self.mul2(&r); | ||
} else if self.greater_than(&u, &v) { | ||
|
||
} | ||
i+= 1; | ||
} | ||
Self::zero_array() | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Add for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
/// Performs modular addition. | ||
/// | ||
/// This method adds two field elements and reduces the result modulo the field's modulus. | ||
fn add(self, other: Self) -> Self { | ||
Self::new(self.modulus, self.add_mod_internal(&self.value, &other.value)) | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Neg for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
fn neg(self) -> Self { | ||
Self::new(self.modulus, self.neg_internal(&self.value)) | ||
} | ||
} | ||
|
||
impl<const L: usize, const D: usize> Sub for FinitePrimeField<L, D> { | ||
type Output = Self; | ||
|
||
/// Performs modular subtraction. | ||
/// | ||
/// This method subtracts one field element from another and ensures the result | ||
/// is in the correct range by adding the modulus if necessary. | ||
fn sub(self, other: Self) -> Self { | ||
Self::new(self.modulus, self.sub_internal(&self.value,&other.value)) | ||
} | ||
} | ||
|
||
// TODO(Make this constant time) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These values could be computed beforehand rather than having a mutation occur.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed, this was left over from before I moved arithmetic to
_internal
s, will remedy