Skip to content

Commit

Permalink
Fix u128 log rounding errors, and overflow reverts (#6647)
Browse files Browse the repository at this point in the history
## Description
Same as #6163 but for u128

Ensures the reverts on overflow and unsafe math take into account the
flags set

## Checklist

- [ ] I have linked to any relevant issues.
- [ ] I have commented my code, particularly in hard-to-understand
areas.
- [ ] I have updated the documentation where relevant (API docs, the
reference, and the Sway book).
- [ ] If my change requires substantial documentation changes, I have
[requested support from the DevRel
team](https://github.com/FuelLabs/devrel-requests/issues/new/choose)
- [ ] I have added tests that prove my fix is effective or that my
feature works.
- [ ] I have added (or requested a maintainer to add) the necessary
`Breaking*` or `New Feature` labels where relevant.
- [ ] I have done my best to ensure that my PR adheres to [the Fuel Labs
Code Review
Standards](https://github.com/FuelLabs/rfcs/blob/master/text/code-standards/external-contributors.md).
- [ ] I have requested a review from the relevant team or maintainers.

---------

Co-authored-by: K1-R1 <[email protected]>
  • Loading branch information
SwayStar123 and K1-R1 authored Oct 16, 2024
1 parent 98e093f commit 9f22fe8
Show file tree
Hide file tree
Showing 3 changed files with 494 additions and 25 deletions.
161 changes: 145 additions & 16 deletions sway-lib-std/src/u128.sw
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@ library;

use ::assert::assert;
use ::convert::{From, Into};
use ::flags::{disable_panic_on_overflow, set_flags};
use ::flags::{
disable_panic_on_overflow,
panic_on_overflow_enabled,
panic_on_unsafe_math_enabled,
set_flags,
};
use ::registers::{flags, overflow};
use ::math::*;
use ::result::Result::{self, *};
use ::option::Option::{self, None, Some};

/// The 128-bit unsigned integer type.
///
Expand Down Expand Up @@ -550,8 +557,11 @@ impl core::ops::Add for U128 {
fn add(self, other: Self) -> Self {
let mut upper_128 = self.upper.overflowing_add(other.upper);

// If the upper overflows, then the number cannot fit in 128 bits, so panic.
assert(upper_128.upper == 0);
if panic_on_overflow_enabled() {
// If the upper overflows, then the number cannot fit in 128 bits, so panic.
assert(upper_128.upper == 0);
}

let lower_128 = self.lower.overflowing_add(other.lower);

// If overflow has occurred in the lower component addition, carry.
Expand All @@ -560,8 +570,10 @@ impl core::ops::Add for U128 {
upper_128 = upper_128.lower.overflowing_add(lower_128.upper);
}

// If overflow has occurred in the upper component addition, panic.
assert(upper_128.upper == 0);
if panic_on_overflow_enabled() {
// If overflow has occurred in the upper component addition, panic.
assert(upper_128.upper == 0);
}

Self {
upper: upper_128.lower,
Expand All @@ -571,10 +583,13 @@ impl core::ops::Add for U128 {
}

impl core::ops::Subtract for U128 {
/// Subtract a `U128` from a `U128`. Reverts of overflow.
/// Subtract a `U128` from a `U128`. Reverts on underflow.
fn subtract(self, other: Self) -> Self {
// If trying to subtract a larger number, panic.
assert(!(self < other));
// panic_on_overflow_enabled is also for underflow
if panic_on_overflow_enabled() {
// If trying to subtract a larger number, panic.
assert(!(self < other));
}

let mut upper = self.upper - other.upper;
let mut lower = 0;
Expand All @@ -596,7 +611,9 @@ impl core::ops::Multiply for U128 {
// in case both of the `U128` upper parts are bigger than zero,
// it automatically means overflow, as any `U128` value
// is upper part multiplied by 2 ^ 64 + lower part
assert(self.upper == 0 || other.upper == 0);
if panic_on_unsafe_math_enabled() {
assert(self.upper == 0 || other.upper == 0);
}

let mut result = self.lower.overflowing_mul(other.lower);
if self.upper == 0 {
Expand All @@ -616,7 +633,9 @@ impl core::ops::Divide for U128 {
fn divide(self, divisor: Self) -> Self {
let zero = Self::from((0, 0));

assert(divisor != zero);
if panic_on_unsafe_math_enabled() {
assert(divisor != zero);
}

if self.upper == 0 && divisor.upper == 0 {
return Self::from((0, self.lower / divisor.lower));
Expand Down Expand Up @@ -645,6 +664,48 @@ impl core::ops::Divide for U128 {
}
}

fn u64_checked_add(a: u64, b: u64) -> Option<u64> {
let of = asm(a: a, b: b, res) {
add res a b;
of: u64
};

if of != 0 {
return None;
}

Some(a + b)
}

fn u128_checked_mul(a: U128, b: U128) -> Option<U128> {
// in case both of the `U128` upper parts are bigger than zero,
// it automatically means overflow, as any `U128` value
// is upper part multiplied by 2 ^ 64 + lower part
if a.upper != 0 || b.upper != 0 {
return None
}

let mut result = a.lower.overflowing_mul(b.lower);

if a.upper == 0 {
match u64_checked_add(result.upper, a.lower * b.upper) {
None => return None,
Some(v) => {
result.upper = v
}
}
} else if b.upper == 0 {
match u64_checked_add(result.upper, a.upper * b.lower) {
None => return None,
Some(v) => {
result.upper = v
}
}
}

Some(result)
}

impl Power for U128 {
fn pow(self, exponent: u32) -> Self {
let mut value = self;
Expand All @@ -661,7 +722,10 @@ impl Power for U128 {
}

while exp & 1 == 0 {
value = value * value;
match u128_checked_mul(value, value) {
None => return U128::zero(),
Some(v) => value = v,
};
exp >>= 1;
}

Expand All @@ -672,9 +736,15 @@ impl Power for U128 {
let mut acc = value;
while exp > 1 {
exp >>= 1;
value = value * value;
match u128_checked_mul(value, value) {
None => return U128::zero(),
Some(v) => value = v,
};
if exp & 1 == 1 {
acc = acc * value;
match u128_checked_mul(acc, value) {
None => return U128::zero(),
Some(v) => acc = v,
};
}
}
acc
Expand Down Expand Up @@ -713,8 +783,14 @@ impl BinaryLogarithm for U128 {
fn log2(self) -> Self {
let zero = Self::from((0, 0));
let mut res = zero;
// If trying to get a log2(0), panic, as infinity is not a number.
assert(self != zero);
// If panic on unsafe math is enabled, only then revert
if panic_on_unsafe_math_enabled() {
assert(self != zero);
} else {
if self == zero {
return zero;
}
}
if self.upper != 0 {
res = Self::from((0, self.upper.log(2) + 64));
} else if self.lower != 0 {
Expand All @@ -726,8 +802,61 @@ impl BinaryLogarithm for U128 {

impl Logarithm for U128 {
fn log(self, base: Self) -> Self {
let flags = disable_panic_on_overflow();

// If panic on unsafe math is enabled, only then revert
if panic_on_unsafe_math_enabled() {
// Logarithm is undefined for bases less than 2
assert(base >= U128::from(2_u64));
// Logarithm is undefined for 0
assert(self != U128::zero());
} else {
// Logarithm is undefined for bases less than 2
// Logarithm is undefined for 0
if (base < U128::from(2_u64)) || (self == U128::zero()) {
set_flags(flags);
return U128::zero();
}
}

// Decimals rounded to 0
if self < base {
set_flags(flags);
return U128::zero();
}

// Estimating the result using change of base formula. Only an estimate because we are doing uint calculations.
let self_log2 = self.log2();
let base_log2 = base.log2();
self_log2 / base_log2
let mut result = (self_log2 / base_log2);

// Converting u128 to u32, this cannot fail as the result will be atmost ~128
let parts: (u64, u64) = result.into();
let res_u32 = asm(r1: parts.1) {
r1: u32
};

// Raising the base to the power of the result
let mut pow_res = base.pow(res_u32);
let mut of = overflow();

// Adjusting the result until the power is less than or equal to self
// If pow_res is > than self, then there is an overestimation. If there is an overflow then there is definitely an overestimation.
while (pow_res > self) || (of > 0) {
result -= U128::from(1_u64);

// Converting u128 to u32, this cannot fail as the result will be atmost ~128
let parts: (u64, u64) = result.into();
let res_u32 = asm(r1: parts.1) {
r1: u32
};

pow_res = base.pow(res_u32);
of = overflow();
};

set_flags(flags);

result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ fn main() -> bool {
let u_128_8: U128 = U128::from((0, 8));
let u_128_9: U128 = U128::from((0, 9));
let u_128_10: U128 = U128::from((0, 10));
let u_128_21: U128 = U128::from((0, 21));
let u_128_20: U128 = U128::from((0, 20));
let u_128_42: U128 = U128::from((0, 42));
let u_128_64: U128 = U128::from((0, 64));
let u_128_100: U128 = U128::from((0, 100));
let u_128_127: U128 = U128::from((0, 127));
let u_128_max_div_2: U128 = U128::from((1, 0));
let u64_max_times_two: U128 = U128::from((1, 0));
let u_128_max: U128 = U128::max();


Expand All @@ -29,8 +29,8 @@ fn main() -> bool {
assert(u_128_100.log(u_128_9) == u_128_2);
assert(u_128_max.log(u_128_2) == u_128_127);
assert(u_128_max.log(u_128_9) == u_128_42);
assert(u_128_max_div_2.log(u_128_2) == u_128_64);
assert(u_128_max_div_2.log(u_128_9) == u_128_21);
assert(u64_max_times_two.log(u_128_2) == u_128_64);
assert(u64_max_times_two.log(u_128_9) == u_128_20);

true
}
Loading

0 comments on commit 9f22fe8

Please sign in to comment.