Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix signed saturating math functions #2306

Merged
merged 2 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions util/arbmath/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ func MaxInt[T Number](values ...T) T {
return max
}

// AbsValue the absolute value of a number
func AbsValue[T Number](value T) T {
if value < 0 {
return -value // never happens for unsigned types
}
return value
}

// Checks if two ints are sufficiently close to one another
func Within[T Unsigned](a, b, bound T) bool {
min := MinInt(a, b)
Expand Down Expand Up @@ -267,14 +259,22 @@ func BigFloatMulByUint(multiplicand *big.Float, multiplier uint64) *big.Float {
return new(big.Float).Mul(multiplicand, UintToBigFloat(multiplier))
}

func MaxSignedValue[T Signed]() T {
return T((uint64(1) << (8*unsafe.Sizeof(T(0)) - 1)) - 1)
}

func MinSignedValue[T Signed]() T {
return T(uint64(1) << ((8 * unsafe.Sizeof(T(0))) - 1))
}

// SaturatingAdd add two integers without overflow
func SaturatingAdd[T Signed](a, b T) T {
sum := a + b
if b > 0 && sum < a {
sum = ^T(0) >> 1
sum = MaxSignedValue[T]()
}
if b < 0 && sum > a {
sum = (^T(0) >> 1) + 1
sum = MinSignedValue[T]()
}
return sum
}
Expand All @@ -290,7 +290,11 @@ func SaturatingUAdd[T Unsigned](a, b T) T {

// SaturatingSub subtract an int64 from another without overflow
func SaturatingSub(minuend, subtrahend int64) int64 {
return SaturatingAdd(minuend, -subtrahend)
if subtrahend == math.MinInt64 {
// The absolute value of MinInt64 is one greater than MaxInt64
return SaturatingAdd(SaturatingAdd(minuend, math.MaxInt64), 1)
}
return SaturatingAdd(minuend, SaturatingNeg(subtrahend))
}

// SaturatingUSub subtract an integer from another without underflow
Expand All @@ -315,9 +319,9 @@ func SaturatingMul[T Signed](a, b T) T {
product := a * b
if b != 0 && product/b != a {
if (a > 0 && b > 0) || (a < 0 && b < 0) {
product = ^T(0) >> 1
product = MaxSignedValue[T]()
} else {
product = (^T(0) >> 1) + 1
product = MinSignedValue[T]()
}
}
return product
Expand Down Expand Up @@ -367,8 +371,8 @@ func SaturatingCastToUint(value *big.Int) uint64 {

// Negates an int without underflow
func SaturatingNeg[T Signed](value T) T {
if value == ^T(0) {
return (^T(0)) >> 1
if value < 0 && value == MinSignedValue[T]() {
return MaxSignedValue[T]()
}
return -value
}
Expand Down
112 changes: 112 additions & 0 deletions util/arbmath/math_fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2024, Offchain Labs, Inc.
// For license information, see https://github.com/nitro/blob/master/LICENSE

package arbmath

import (
"math/big"
"testing"
)

func toBig[T Signed](a T) *big.Int {
return big.NewInt(int64(a))
}

func saturatingBigToInt[T Signed](a *big.Int) T {
// MinSignedValue and MaxSignedValue are already separately tested
if a.Cmp(toBig(MaxSignedValue[T]())) > 0 {
return MaxSignedValue[T]()
}
if a.Cmp(toBig(MinSignedValue[T]())) < 0 {
return MinSignedValue[T]()
}
return T(a.Int64())
}

func fuzzSaturatingAdd[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a, b T) {
got := SaturatingAdd(a, b)
expected := saturatingBigToInt[T](new(big.Int).Add(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingAdd(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func fuzzSaturatingMul[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a, b T) {
got := SaturatingMul(a, b)
expected := saturatingBigToInt[T](new(big.Int).Mul(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingMul(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func fuzzSaturatingNeg[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a T) {
got := SaturatingNeg(a)
expected := saturatingBigToInt[T](new(big.Int).Neg(toBig(a)))
if got != expected {
t.Errorf("SaturatingNeg(%v) = %v, expected %v", a, got, expected)
}
})
}

func FuzzSaturatingAddInt8(f *testing.F) {
fuzzSaturatingAdd[int8](f)
}

func FuzzSaturatingAddInt16(f *testing.F) {
fuzzSaturatingAdd[int16](f)
}

func FuzzSaturatingAddInt32(f *testing.F) {
fuzzSaturatingAdd[int32](f)
}

func FuzzSaturatingAddInt64(f *testing.F) {
fuzzSaturatingAdd[int64](f)
}

func FuzzSaturatingSub(f *testing.F) {
f.Fuzz(func(t *testing.T, a, b int64) {
got := SaturatingSub(a, b)
expected := saturatingBigToInt[int64](new(big.Int).Sub(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingSub(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func FuzzSaturatingMulInt8(f *testing.F) {
fuzzSaturatingMul[int8](f)
}

func FuzzSaturatingMulInt16(f *testing.F) {
fuzzSaturatingMul[int16](f)
}

func FuzzSaturatingMulInt32(f *testing.F) {
fuzzSaturatingMul[int32](f)
}

func FuzzSaturatingMulInt64(f *testing.F) {
fuzzSaturatingMul[int64](f)
}

func FuzzSaturatingNegInt8(f *testing.F) {
fuzzSaturatingNeg[int8](f)
}

func FuzzSaturatingNegInt16(f *testing.F) {
fuzzSaturatingNeg[int16](f)
}

func FuzzSaturatingNegInt32(f *testing.F) {
fuzzSaturatingNeg[int32](f)
}

func FuzzSaturatingNegInt64(f *testing.F) {
fuzzSaturatingNeg[int64](f)
}
105 changes: 105 additions & 0 deletions util/arbmath/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package arbmath

import (
"bytes"
"fmt"
"math"
"math/rand"
"testing"
Expand Down Expand Up @@ -120,6 +121,110 @@ func TestSlices(t *testing.T) {
assert_eq(SliceWithRunoff(data, 7, 8), []uint8{})
}

func testMinMaxSignedValues[T Signed](t *testing.T, min T, max T) {
gotMin := MinSignedValue[T]()
if gotMin != min {
Fail(t, "expected min", min, "but got", gotMin)
}
gotMax := MaxSignedValue[T]()
if gotMax != max {
Fail(t, "expected max", max, "but got", gotMax)
}
}

func TestMinMaxSignedValues(t *testing.T) {
testMinMaxSignedValues[int8](t, math.MinInt8, math.MaxInt8)
testMinMaxSignedValues[int16](t, math.MinInt16, math.MaxInt16)
testMinMaxSignedValues[int32](t, math.MinInt32, math.MaxInt32)
testMinMaxSignedValues[int64](t, math.MinInt64, math.MaxInt64)
}

func TestSaturatingAdd(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{2, 3, 5},
{-1, -2, -3},
{math.MaxInt64, 1, math.MaxInt64},
{math.MaxInt64, math.MaxInt64, math.MaxInt64},
{math.MinInt64, -1, math.MinInt64},
{math.MinInt64, math.MinInt64, math.MinInt64},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("%v + %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
sum := SaturatingAdd(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingAdd(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingSub(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{5, 3, 2},
{-3, -2, -1},
{math.MinInt64, 1, math.MinInt64},
{math.MinInt64, -1, math.MinInt64 + 1},
{math.MinInt64, math.MinInt64, 0},
{0, math.MinInt64, math.MaxInt64},
}

for _, tc := range tests {
t.Run("", func(t *testing.T) {
sum := SaturatingSub(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingSub(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingMul(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{5, 3, 15},
{-3, -2, 6},
{math.MaxInt64, 2, math.MaxInt64},
{math.MinInt64, 2, math.MinInt64},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("%v - %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
sum := SaturatingMul(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingMul(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingNeg(t *testing.T) {
tests := []struct {
value int64
expected int64
}{
{0, 0},
{5, -5},
{-5, 5},
{math.MinInt64, math.MaxInt64},
{math.MaxInt64, math.MinInt64 + 1},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("-%v = %v", tc.value, tc.expected), func(t *testing.T) {
result := SaturatingNeg(tc.value)
if result != tc.expected {
t.Errorf("SaturatingNeg(%v) = %v: expected %v", tc.value, result, tc.expected)
}
})
}
}

func Fail(t *testing.T, printables ...interface{}) {
t.Helper()
testhelpers.FailImpl(t, printables...)
Expand Down
Loading