Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] first pass at framing arithmetic #10

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 159 additions & 84 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use num_traits::{Inv, Zero};
use std::ops::{Add, Div, Mul, Neg, Sub};

/// A finite field scalar optimized for use in cryptographic operations.
///
/// All operations feature modular arithmetic, implemented in constant time.
Expand Down Expand Up @@ -28,16 +27,18 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> {
}
// TODO(Cache these for a given modulus for the lifetime of the program)
// If it can be done in a way which doesn't introduce side-channel attacks
let correction = Self::subtraction_correction(&modulus);
let r_squared = Self::compute_r_squared(&modulus);
let n_prime = Self::compute_n_prime(&modulus);
Self {
let r = Self::compute_r(&modulus);
let mut retval = Self {
modulus,
value,
correction,
r_squared,
n_prime,
}
correction: [0u64; L],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These values could be computed beforehand rather than having a mutation occur.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, this was left over from before I moved arithmetic to _internals, will remedy

r_squared: [0u64; L],
n_prime: 0u64,
};
retval.correction = Self::subtraction_correction(&modulus);
retval.n_prime = Self::compute_n_prime(&modulus);
retval.r_squared = retval.montgomery_multiply(&r, &r);
retval
}

/// Computes the correction factor for efficient subtraction.
Expand Down Expand Up @@ -86,33 +87,40 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> {
}

const fn compute_r(modulus: &[u64; L]) -> [u64; L] {
// TODO (Implement Montgomery r squared)
Self::zero_array()
}
let diff = Self::subtraction_correction(modulus);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be passed in to the function rather than recomputed.

let mut r = [0u64; L];

const fn compute_r_squared(modulus: &[u64; L]) -> [u64; L] {
// TODO (Implement Montgomery r squared)
Self::zero_array()
}

const fn compute_r_cubed(modulus: &[u64; L]) -> [u64; L] {
// TODO (Implement Montgomery r squared)
Self::zero_array()
//R = (2^(64*L) - modulus) + modulus
0xAlcibiades marked this conversation as resolved.
Show resolved Hide resolved
let mut i = 0;
let mut carry = 0u64;
while i < L {
let (sum, c1) = diff[i].overflowing_add(modulus[i]);
let (sum, c2 ) = sum.overflowing_add(carry);
r[i] = sum;
carry = (c1 as u64) + (c2 as u64);
i+= 1;
}
r //by definition in range, so no need to montgomery reduce
}

const fn compute_n_prime(modulus: &[u64; L]) -> u64 {
// TODO (Implement Montgomery n prime)
0u64
let n = modulus[0]; //need only least significant bits
let mut n_prime = 1u64;
let mut i = 0;
while i < 64 {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this number, because 64 bits?

n_prime = n_prime.wrapping_mul(n);
n_prime = n_prime.wrapping_mul(2u64.wrapping_sub(n.wrapping_mul(n_prime)));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this number?

i += 1;
}
n_prime.wrapping_neg()
}

pub const fn to_montgomery(&self, a: &[u64; L]) -> [u64; L] {
// TODO (Implement to monty form)
Self::zero_array()
self.montgomery_multiply(a, &self.r_squared)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some detail on why multiplying by r_squared gives us the monty form would be helpful.

}

pub const fn from_montgomery(&self, a: &[u64; L]) -> [u64; L] {
// TODO (Implement from monty form)
Self::zero_array()
self.montgomery_multiply(a, &Self::one_array())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here on why multiplying by 1 (the identity, not an array filled with ones) gives us the normal form.

}

/// Performs Montgomery multiplication of two large integers represented as arrays of u64.
Expand Down Expand Up @@ -198,110 +206,177 @@ impl<const L: usize, const D: usize> FinitePrimeField<L, D> {
j += 1;
}

result
self.to_montgomery(&result)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why here do we convert to_montgomery? Should this not already be in monty form at the end of this, or has it been reduced?

}

pub const fn bernstein_yang_invert(&self, a: &[u64; L]) -> [u64; L] {
// TODO: implement bernstein yang inversion
Self::zero_array()

pub const fn greater_than(&self, a: &[u64; L], b:&[u64;L])-> bool {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not truly constant time as different branches on the comparator may return at different times, we should instead always run through the whole loop and then return only afterwards.

let mut i = L;
while i>0 {
i -= 1;
if a[i] > b[i] {
return true;
}
if a[i] < b[i] {
return false;
}
}
false
}
pub const fn is_zero(&self, a: &[u64; L]) -> bool {
let mut retval = true;
let mut i = 0;
while i < L {
retval &= a[i] == 0;
i += 1;
}
retval
}
}

impl<const L: usize, const D: usize> Add for FinitePrimeField<L, D> {
type Output = Self;

/// Performs modular addition.
///
/// This method adds two field elements and reduces the result modulo the field's modulus.
fn add(self, other: Self) -> Self {
pub const fn add_mod_internal(&self, a: &[u64; L], b: &[u64; L]) -> [u64;L] {
// Initialize sum to zero
let mut sum = Self::new(self.modulus, Self::zero_array());
let mut sum = Self::zero_array();
let mut carry = false;
let mut i = 0;

// Perform addition with carry propagation
for i in 0..L {
let sum_with_other = self.value[i].overflowing_add(other.value[i]);
while i < L {
let sum_with_other = a[i].overflowing_add(b[i]);
let sum_with_carry = sum_with_other.0.overflowing_add(if carry { 1 } else { 0 });
sum.value[i] = sum_with_carry.0;
sum[i] = sum_with_carry.0;
carry = sum_with_other.1 | sum_with_carry.1;
i += 1;
}

// Perform trial subtraction of modulus
let mut trial = Self::new(self.modulus, Self::zero_array());
i = 0;
let mut trial = Self::zero_array();
let mut borrow = false;
for i in 0..L {
while i < L {
// Note: a single overflowing_sub is enough because modulus[i]+borrow can never overflow
let diff_with_borrow =
sum.value[i].overflowing_sub(self.modulus[i] + if borrow { 1 } else { 0 });
trial.value[i] = diff_with_borrow.0;
sum[i].overflowing_sub(self.modulus[i] + if borrow { 1 } else { 0 });
trial[i] = diff_with_borrow.0;
borrow = diff_with_borrow.1;
i += 1;
}

// Select between sum and trial based on borrow flag
let mut result = Self::new(self.modulus, Self::zero_array());
i = 0;
let mut result = Self::zero_array();
let select_mask = u64::from(borrow).wrapping_neg();
for i in 0..L {
while i < L {
// If borrow is true (select_mask is all 1s), choose sum, otherwise choose trial
result.value[i] = (select_mask & sum.value[i]) | (!select_mask & trial.value[i]);
result[i] = (select_mask & sum[i]) | (!select_mask & trial[i]);
i += 1;
}
result
}
}

impl<const L: usize, const D: usize> Neg for FinitePrimeField<L, D> {
type Output = Self;

fn neg(self) -> Self {
let zero = Self::new(self.modulus, Self::zero_array());
let z = self == zero;
let mut negated = Self::new(self.modulus, Self::zero_array());
for i in 0..L {
negated.value[i] = self.modulus[i].wrapping_sub(self.value[i]);
}
if z {
zero
} else {
negated
}
}
}

impl<const L: usize, const D: usize> Sub for FinitePrimeField<L, D> {
type Output = Self;

/// Performs modular subtraction.
///
/// This method subtracts one field element from another and ensures the result
/// is in the correct range by adding the modulus if necessary.
fn sub(self, other: Self) -> Self {
pub const fn sub_internal(&self, a: &[u64; L], b: &[u64; L]) -> [u64; L] {
// Initialize difference to zero
let mut difference = Self::new(self.modulus, Self::zero_array());
let mut difference = Self::zero_array();
let mut borrow = false;
let mut i = 0;

// Perform subtraction with borrow propagation
for i in 0..L {
let diff_without_borrow = self.value[i].overflowing_sub(other.value[i]);
while i < L {
let diff_without_borrow = a[i].overflowing_sub(b[i]);
let diff_with_borrow =
diff_without_borrow
.0
.overflowing_sub(if borrow { 1 } else { 0 });
difference.value[i] = diff_with_borrow.0;
difference[i] = diff_with_borrow.0;
borrow = diff_without_borrow.1 | diff_with_borrow.1;
i += 1;
}

// Always subtract the correction, which effectively adds the modulus if borrow occurred
let correction_mask = u64::from(borrow).wrapping_neg();
let mut correction_borrow = false;
for i in 0..L {
i = 0;
while i < L {
let correction_term =
(correction_mask & self.correction[i]) + if correction_borrow { 1 } else { 0 };
let (corrected_limb, new_borrow) = difference.value[i].overflowing_sub(correction_term);
difference.value[i] = corrected_limb;
let (corrected_limb, new_borrow) = difference[i].overflowing_sub(correction_term);
difference[i] = corrected_limb;
correction_borrow = new_borrow;
i += 1;
}

difference
}

pub const fn neg_internal(&self, a: &[u64; L]) -> [u64; L] {
let zero = Self::zero_array();
let z = self.is_zero(a);
let mut negated = Self::zero_array();
let mut i = 0;
while i < L {
negated[i] = self.modulus[i].wrapping_sub(a[i]);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking some carry about the modulus needs to happen here.

i += 1;
}
if z {
zero
} else {
negated
}
}
/// The following performs the Bernstein-Yang inversion on a scalar.
pub const fn bernstein_yang_invert(&self, a: &[u64; L]) -> [u64; L] {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect use case for dimensionals 💯

let mut u = *a;
let mut v = self.modulus;
let mut r = Self::zero_array();
let mut s = Self::one_array();

let mut i = 0;
while i < 256*L {
if self.is_zero(&v) {
break;
}
if self.is_even(&u){
u = self.div2(&u);
s = self.mul2(&s);
} else if self.is_even(&v){
v = self.div2(&v);
r = self.mul2(&r);
} else if self.greater_than(&u, &v) {

}
i+= 1;
}
Self::zero_array()
}
}

impl<const L: usize, const D: usize> Add for FinitePrimeField<L, D> {
type Output = Self;

/// Performs modular addition.
///
/// This method adds two field elements and reduces the result modulo the field's modulus.
fn add(self, other: Self) -> Self {
Self::new(self.modulus, self.add_mod_internal(&self.value, &other.value))
}
}

impl<const L: usize, const D: usize> Neg for FinitePrimeField<L, D> {
type Output = Self;

fn neg(self) -> Self {
Self::new(self.modulus, self.neg_internal(&self.value))
}
}

impl<const L: usize, const D: usize> Sub for FinitePrimeField<L, D> {
type Output = Self;

/// Performs modular subtraction.
///
/// This method subtracts one field element from another and ensures the result
/// is in the correct range by adding the modulus if necessary.
fn sub(self, other: Self) -> Self {
Self::new(self.modulus, self.sub_internal(&self.value,&other.value))
}
}

// TODO(Make this constant time)
Expand Down
Loading