diff --git a/util/arbmath/math.go b/util/arbmath/math.go index 1c11c6ad58..d7a0d1f523 100644 --- a/util/arbmath/math.go +++ b/util/arbmath/math.go @@ -74,14 +74,6 @@ func MaxInt[T Number](values ...T) T { return max } -// AbsValue the absolute value of a number -func AbsValue[T Number](value T) T { - if value < 0 { - return -value // never happens for unsigned types - } - return value -} - // Checks if two ints are sufficiently close to one another func Within[T Unsigned](a, b, bound T) bool { min := MinInt(a, b) @@ -267,14 +259,22 @@ func BigFloatMulByUint(multiplicand *big.Float, multiplier uint64) *big.Float { return new(big.Float).Mul(multiplicand, UintToBigFloat(multiplier)) } +func MaxSignedValue[T Signed]() T { + return T((uint64(1) << (8*unsafe.Sizeof(T(0)) - 1)) - 1) +} + +func MinSignedValue[T Signed]() T { + return T(uint64(1) << ((8 * unsafe.Sizeof(T(0))) - 1)) +} + // SaturatingAdd add two integers without overflow func SaturatingAdd[T Signed](a, b T) T { sum := a + b if b > 0 && sum < a { - sum = ^T(0) >> 1 + sum = MaxSignedValue[T]() } if b < 0 && sum > a { - sum = (^T(0) >> 1) + 1 + sum = MinSignedValue[T]() } return sum } @@ -290,7 +290,11 @@ func SaturatingUAdd[T Unsigned](a, b T) T { // SaturatingSub subtract an int64 from another without overflow func SaturatingSub(minuend, subtrahend int64) int64 { - return SaturatingAdd(minuend, -subtrahend) + if subtrahend == math.MinInt64 { + // The absolute value of MinInt64 is one greater than MaxInt64 + return SaturatingAdd(SaturatingAdd(minuend, math.MaxInt64), 1) + } + return SaturatingAdd(minuend, SaturatingNeg(subtrahend)) } // SaturatingUSub subtract an integer from another without underflow @@ -315,9 +319,9 @@ func SaturatingMul[T Signed](a, b T) T { product := a * b if b != 0 && product/b != a { if (a > 0 && b > 0) || (a < 0 && b < 0) { - product = ^T(0) >> 1 + product = MaxSignedValue[T]() } else { - product = (^T(0) >> 1) + 1 + product = MinSignedValue[T]() } } return product @@ -367,8 +371,8 @@ func SaturatingCastToUint(value *big.Int) uint64 { // Negates an int without underflow func SaturatingNeg[T Signed](value T) T { - if value == ^T(0) { - return (^T(0)) >> 1 + if value < 0 && value == MinSignedValue[T]() { + return MaxSignedValue[T]() } return -value } diff --git a/util/arbmath/math_fuzz_test.go b/util/arbmath/math_fuzz_test.go new file mode 100644 index 0000000000..591d699de0 --- /dev/null +++ b/util/arbmath/math_fuzz_test.go @@ -0,0 +1,112 @@ +// Copyright 2024, Offchain Labs, Inc. +// For license information, see https://github.com/nitro/blob/master/LICENSE + +package arbmath + +import ( + "math/big" + "testing" +) + +func toBig[T Signed](a T) *big.Int { + return big.NewInt(int64(a)) +} + +func saturatingBigToInt[T Signed](a *big.Int) T { + // MinSignedValue and MaxSignedValue are already separately tested + if a.Cmp(toBig(MaxSignedValue[T]())) > 0 { + return MaxSignedValue[T]() + } + if a.Cmp(toBig(MinSignedValue[T]())) < 0 { + return MinSignedValue[T]() + } + return T(a.Int64()) +} + +func fuzzSaturatingAdd[T Signed](f *testing.F) { + f.Fuzz(func(t *testing.T, a, b T) { + got := SaturatingAdd(a, b) + expected := saturatingBigToInt[T](new(big.Int).Add(toBig(a), toBig(b))) + if got != expected { + t.Errorf("SaturatingAdd(%v, %v) = %v, expected %v", a, b, got, expected) + } + }) +} + +func fuzzSaturatingMul[T Signed](f *testing.F) { + f.Fuzz(func(t *testing.T, a, b T) { + got := SaturatingMul(a, b) + expected := saturatingBigToInt[T](new(big.Int).Mul(toBig(a), toBig(b))) + if got != expected { + t.Errorf("SaturatingMul(%v, %v) = %v, expected %v", a, b, got, expected) + } + }) +} + +func fuzzSaturatingNeg[T Signed](f *testing.F) { + f.Fuzz(func(t *testing.T, a T) { + got := SaturatingNeg(a) + expected := saturatingBigToInt[T](new(big.Int).Neg(toBig(a))) + if got != expected { + t.Errorf("SaturatingNeg(%v) = %v, expected %v", a, got, expected) + } + }) +} + +func FuzzSaturatingAddInt8(f *testing.F) { + fuzzSaturatingAdd[int8](f) +} + +func FuzzSaturatingAddInt16(f *testing.F) { + fuzzSaturatingAdd[int16](f) +} + +func FuzzSaturatingAddInt32(f *testing.F) { + fuzzSaturatingAdd[int32](f) +} + +func FuzzSaturatingAddInt64(f *testing.F) { + fuzzSaturatingAdd[int64](f) +} + +func FuzzSaturatingSub(f *testing.F) { + f.Fuzz(func(t *testing.T, a, b int64) { + got := SaturatingSub(a, b) + expected := saturatingBigToInt[int64](new(big.Int).Sub(toBig(a), toBig(b))) + if got != expected { + t.Errorf("SaturatingSub(%v, %v) = %v, expected %v", a, b, got, expected) + } + }) +} + +func FuzzSaturatingMulInt8(f *testing.F) { + fuzzSaturatingMul[int8](f) +} + +func FuzzSaturatingMulInt16(f *testing.F) { + fuzzSaturatingMul[int16](f) +} + +func FuzzSaturatingMulInt32(f *testing.F) { + fuzzSaturatingMul[int32](f) +} + +func FuzzSaturatingMulInt64(f *testing.F) { + fuzzSaturatingMul[int64](f) +} + +func FuzzSaturatingNegInt8(f *testing.F) { + fuzzSaturatingNeg[int8](f) +} + +func FuzzSaturatingNegInt16(f *testing.F) { + fuzzSaturatingNeg[int16](f) +} + +func FuzzSaturatingNegInt32(f *testing.F) { + fuzzSaturatingNeg[int32](f) +} + +func FuzzSaturatingNegInt64(f *testing.F) { + fuzzSaturatingNeg[int64](f) +} diff --git a/util/arbmath/math_test.go b/util/arbmath/math_test.go index 2e2f14795a..1be60dc58b 100644 --- a/util/arbmath/math_test.go +++ b/util/arbmath/math_test.go @@ -5,6 +5,7 @@ package arbmath import ( "bytes" + "fmt" "math" "math/rand" "testing" @@ -120,6 +121,110 @@ func TestSlices(t *testing.T) { assert_eq(SliceWithRunoff(data, 7, 8), []uint8{}) } +func testMinMaxSignedValues[T Signed](t *testing.T, min T, max T) { + gotMin := MinSignedValue[T]() + if gotMin != min { + Fail(t, "expected min", min, "but got", gotMin) + } + gotMax := MaxSignedValue[T]() + if gotMax != max { + Fail(t, "expected max", max, "but got", gotMax) + } +} + +func TestMinMaxSignedValues(t *testing.T) { + testMinMaxSignedValues[int8](t, math.MinInt8, math.MaxInt8) + testMinMaxSignedValues[int16](t, math.MinInt16, math.MaxInt16) + testMinMaxSignedValues[int32](t, math.MinInt32, math.MaxInt32) + testMinMaxSignedValues[int64](t, math.MinInt64, math.MaxInt64) +} + +func TestSaturatingAdd(t *testing.T) { + tests := []struct { + a, b, expected int64 + }{ + {2, 3, 5}, + {-1, -2, -3}, + {math.MaxInt64, 1, math.MaxInt64}, + {math.MaxInt64, math.MaxInt64, math.MaxInt64}, + {math.MinInt64, -1, math.MinInt64}, + {math.MinInt64, math.MinInt64, math.MinInt64}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("%v + %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) { + sum := SaturatingAdd(int64(tc.a), int64(tc.b)) + if sum != tc.expected { + t.Errorf("SaturatingAdd(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected) + } + }) + } +} + +func TestSaturatingSub(t *testing.T) { + tests := []struct { + a, b, expected int64 + }{ + {5, 3, 2}, + {-3, -2, -1}, + {math.MinInt64, 1, math.MinInt64}, + {math.MinInt64, -1, math.MinInt64 + 1}, + {math.MinInt64, math.MinInt64, 0}, + {0, math.MinInt64, math.MaxInt64}, + } + + for _, tc := range tests { + t.Run("", func(t *testing.T) { + sum := SaturatingSub(int64(tc.a), int64(tc.b)) + if sum != tc.expected { + t.Errorf("SaturatingSub(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected) + } + }) + } +} + +func TestSaturatingMul(t *testing.T) { + tests := []struct { + a, b, expected int64 + }{ + {5, 3, 15}, + {-3, -2, 6}, + {math.MaxInt64, 2, math.MaxInt64}, + {math.MinInt64, 2, math.MinInt64}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("%v - %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) { + sum := SaturatingMul(int64(tc.a), int64(tc.b)) + if sum != tc.expected { + t.Errorf("SaturatingMul(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected) + } + }) + } +} + +func TestSaturatingNeg(t *testing.T) { + tests := []struct { + value int64 + expected int64 + }{ + {0, 0}, + {5, -5}, + {-5, 5}, + {math.MinInt64, math.MaxInt64}, + {math.MaxInt64, math.MinInt64 + 1}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("-%v = %v", tc.value, tc.expected), func(t *testing.T) { + result := SaturatingNeg(tc.value) + if result != tc.expected { + t.Errorf("SaturatingNeg(%v) = %v: expected %v", tc.value, result, tc.expected) + } + }) + } +} + func Fail(t *testing.T, printables ...interface{}) { t.Helper() testhelpers.FailImpl(t, printables...)