Skip to content

Commit

Permalink
Fix signed saturating math functions
Browse files Browse the repository at this point in the history
  • Loading branch information
PlasmaPower committed May 16, 2024
1 parent 6f73839 commit 4f72ebb
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 15 deletions.
44 changes: 29 additions & 15 deletions util/arbmath/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,6 @@ func MaxInt[T Number](values ...T) T {
return max
}

// AbsValue the absolute value of a number
func AbsValue[T Number](value T) T {
if value < 0 {
return -value // never happens for unsigned types
}
return value
}

// Checks if two ints are sufficiently close to one another
func Within[T Unsigned](a, b, bound T) bool {
min := MinInt(a, b)
Expand Down Expand Up @@ -267,14 +259,32 @@ func BigFloatMulByUint(multiplicand *big.Float, multiplier uint64) *big.Float {
return new(big.Float).Mul(multiplicand, UintToBigFloat(multiplier))
}

func MaxIntValue[T Integer]() T {
allBits := ^T(0)
if allBits < 0 {
// This is a signed integer
return T((uint64(1) << (8*unsafe.Sizeof(allBits) - 1)) - 1)
}
return allBits
}

func MinIntValue[T Integer]() T {
allBits := ^T(0)
if allBits < 0 {
// This is a signed integer
return T(uint64(1) << ((8 * unsafe.Sizeof(allBits)) - 1))
}
return 0
}

// SaturatingAdd add two integers without overflow
func SaturatingAdd[T Signed](a, b T) T {
sum := a + b
if b > 0 && sum < a {
sum = ^T(0) >> 1
sum = MaxIntValue[T]()
}
if b < 0 && sum > a {
sum = (^T(0) >> 1) + 1
sum = MinIntValue[T]()
}
return sum
}
Expand All @@ -290,7 +300,11 @@ func SaturatingUAdd[T Unsigned](a, b T) T {

// SaturatingSub subtract an int64 from another without overflow
func SaturatingSub(minuend, subtrahend int64) int64 {
return SaturatingAdd(minuend, -subtrahend)
if subtrahend == math.MinInt64 {
// The absolute value of MinInt64 is one greater than MaxInt64
return SaturatingAdd(SaturatingAdd(minuend, math.MaxInt64), 1)
}
return SaturatingAdd(minuend, SaturatingNeg(subtrahend))
}

// SaturatingUSub subtract an integer from another without underflow
Expand All @@ -315,9 +329,9 @@ func SaturatingMul[T Signed](a, b T) T {
product := a * b
if b != 0 && product/b != a {
if (a > 0 && b > 0) || (a < 0 && b < 0) {
product = ^T(0) >> 1
product = MaxIntValue[T]()
} else {
product = (^T(0) >> 1) + 1
product = MinIntValue[T]()
}
}
return product
Expand Down Expand Up @@ -367,8 +381,8 @@ func SaturatingCastToUint(value *big.Int) uint64 {

// Negates an int without underflow
func SaturatingNeg[T Signed](value T) T {
if value == ^T(0) {
return (^T(0)) >> 1
if value < 0 && value == MinIntValue[T]() {
return MaxIntValue[T]()
}
return -value
}
Expand Down
112 changes: 112 additions & 0 deletions util/arbmath/math_fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2024, Offchain Labs, Inc.
// For license information, see https://github.com/nitro/blob/master/LICENSE

package arbmath

import (
"math/big"
"testing"
)

func toBig[T Signed](a T) *big.Int {
return big.NewInt(int64(a))
}

func saturatingBigToInt[T Signed](a *big.Int) T {
// MinIntValue and MaxIntValue are already separately tested
if a.Cmp(toBig(MaxIntValue[T]())) > 0 {
return MaxIntValue[T]()
}
if a.Cmp(toBig(MinIntValue[T]())) < 0 {
return MinIntValue[T]()
}
return T(a.Int64())
}

func fuzzSaturatingAdd[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a, b T) {
got := SaturatingAdd(a, b)
expected := saturatingBigToInt[T](new(big.Int).Add(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingAdd(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func fuzzSaturatingMul[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a, b T) {
got := SaturatingMul(a, b)
expected := saturatingBigToInt[T](new(big.Int).Mul(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingMul(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func fuzzSaturatingNeg[T Signed](f *testing.F) {
f.Fuzz(func(t *testing.T, a T) {
got := SaturatingNeg(a)
expected := saturatingBigToInt[T](new(big.Int).Neg(toBig(a)))
if got != expected {
t.Errorf("SaturatingNeg(%v) = %v, expected %v", a, got, expected)
}
})
}

func FuzzSaturatingAddInt8(f *testing.F) {
fuzzSaturatingAdd[int8](f)
}

func FuzzSaturatingAddInt16(f *testing.F) {
fuzzSaturatingAdd[int16](f)
}

func FuzzSaturatingAddInt32(f *testing.F) {
fuzzSaturatingAdd[int32](f)
}

func FuzzSaturatingAddInt64(f *testing.F) {
fuzzSaturatingAdd[int64](f)
}

func FuzzSaturatingSub(f *testing.F) {
f.Fuzz(func(t *testing.T, a, b int64) {
got := SaturatingSub(a, b)
expected := saturatingBigToInt[int64](new(big.Int).Sub(toBig(a), toBig(b)))
if got != expected {
t.Errorf("SaturatingSub(%v, %v) = %v, expected %v", a, b, got, expected)
}
})
}

func FuzzSaturatingMulInt8(f *testing.F) {
fuzzSaturatingMul[int8](f)
}

func FuzzSaturatingMulInt16(f *testing.F) {
fuzzSaturatingMul[int16](f)
}

func FuzzSaturatingMulInt32(f *testing.F) {
fuzzSaturatingMul[int32](f)
}

func FuzzSaturatingMulInt64(f *testing.F) {
fuzzSaturatingMul[int64](f)
}

func FuzzSaturatingNegInt8(f *testing.F) {
fuzzSaturatingNeg[int8](f)
}

func FuzzSaturatingNegInt16(f *testing.F) {
fuzzSaturatingNeg[int16](f)
}

func FuzzSaturatingNegInt32(f *testing.F) {
fuzzSaturatingNeg[int32](f)
}

func FuzzSaturatingNegInt64(f *testing.F) {
fuzzSaturatingNeg[int64](f)
}
109 changes: 109 additions & 0 deletions util/arbmath/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package arbmath

import (
"bytes"
"fmt"
"math"
"math/rand"
"testing"
Expand Down Expand Up @@ -120,6 +121,114 @@ func TestSlices(t *testing.T) {
assert_eq(SliceWithRunoff(data, 7, 8), []uint8{})
}

func testMinMaxValues[T Integer](t *testing.T, min T, max T) {
gotMin := MinIntValue[T]()
if gotMin != min {
Fail(t, "expected min", min, "but got", gotMin)
}
gotMax := MaxIntValue[T]()
if gotMax != max {
Fail(t, "expected max", max, "but got", gotMax)
}
}

func TestMinMaxValues(t *testing.T) {
testMinMaxValues[uint8](t, 0, math.MaxUint8)
testMinMaxValues[uint16](t, 0, math.MaxUint16)
testMinMaxValues[uint32](t, 0, math.MaxUint32)
testMinMaxValues[uint64](t, 0, math.MaxUint64)
testMinMaxValues[int8](t, math.MinInt8, math.MaxInt8)
testMinMaxValues[int16](t, math.MinInt16, math.MaxInt16)
testMinMaxValues[int32](t, math.MinInt32, math.MaxInt32)
testMinMaxValues[int64](t, math.MinInt64, math.MaxInt64)
}

func TestSaturatingAdd(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{2, 3, 5},
{-1, -2, -3},
{math.MaxInt64, 1, math.MaxInt64},
{math.MaxInt64, math.MaxInt64, math.MaxInt64},
{math.MinInt64, -1, math.MinInt64},
{math.MinInt64, math.MinInt64, math.MinInt64},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("%v + %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
sum := SaturatingAdd(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingAdd(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingSub(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{5, 3, 2},
{-3, -2, -1},
{math.MinInt64, 1, math.MinInt64},
{math.MinInt64, -1, math.MinInt64 + 1},
{math.MinInt64, math.MinInt64, 0},
{0, math.MinInt64, math.MaxInt64},
}

for _, tc := range tests {
t.Run("", func(t *testing.T) {
sum := SaturatingSub(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingSub(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingMul(t *testing.T) {
tests := []struct {
a, b, expected int64
}{
{5, 3, 15},
{-3, -2, 6},
{math.MaxInt64, 2, math.MaxInt64},
{math.MinInt64, 2, math.MinInt64},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("%v - %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
sum := SaturatingMul(int64(tc.a), int64(tc.b))
if sum != tc.expected {
t.Errorf("SaturatingMul(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
}
})
}
}

func TestSaturatingNeg(t *testing.T) {
tests := []struct {
value int64
expected int64
}{
{0, 0},
{5, -5},
{-5, 5},
{math.MinInt64, math.MaxInt64},
{math.MaxInt64, math.MinInt64 + 1},
}

for _, tc := range tests {
t.Run(fmt.Sprintf("-%v = %v", tc.value, tc.expected), func(t *testing.T) {
result := SaturatingNeg(tc.value)
if result != tc.expected {
t.Errorf("SaturatingNeg(%v) = %v: expected %v", tc.value, result, tc.expected)
}
})
}
}

func Fail(t *testing.T, printables ...interface{}) {
t.Helper()
testhelpers.FailImpl(t, printables...)
Expand Down

0 comments on commit 4f72ebb

Please sign in to comment.