From c25e56b3c6edce56f59d5a6b5f7b6d622a71110a Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 17 Oct 2024 04:33:05 -0700 Subject: [PATCH] fix(inclusion): fix integer overflow+inefficient Round*PowerOfTwo This code fixes an integer flow using finer bit twiddling and even makes it much more efficient to always run deterministically in O(1) time essentially and not O(k) where k=log2(input) due to the prior k iterations that then caused the overflow when result became > maxInt. While here also added some benchmarks and more test cases. Fixes #117 --- inclusion/blob_share_commitment_rules.go | 18 +- inclusion/blob_share_commitment_rules_test.go | 236 +++++++++++------- 2 files changed, 162 insertions(+), 92 deletions(-) diff --git a/inclusion/blob_share_commitment_rules.go b/inclusion/blob_share_commitment_rules.go index f51bae7..9792fa3 100644 --- a/inclusion/blob_share_commitment_rules.go +++ b/inclusion/blob_share_commitment_rules.go @@ -3,6 +3,7 @@ package inclusion import ( "fmt" "math" + "math/bits" "golang.org/x/exp/constraints" ) @@ -50,11 +51,13 @@ func RoundUpByMultipleOf(cursor, v int) int { // RoundUpPowerOfTwo returns the next power of two greater than or equal to input. func RoundUpPowerOfTwo[I constraints.Integer](input I) I { - var result I = 1 - for result < input { - result <<= 1 + if input <= 1 { + return 1 } - return result + if input&(input-1) == 0 { + return input + } + return 1 << bits.Len64(uint64(input)) } // RoundDownPowerOfTwo returns the next power of two less than or equal to input. @@ -62,11 +65,10 @@ func RoundDownPowerOfTwo[I constraints.Integer](input I) (I, error) { if input <= 0 { return 0, fmt.Errorf("input %v must be positive", input) } - roundedUp := RoundUpPowerOfTwo(input) - if roundedUp == input { - return roundedUp, nil + if input&(input-1) == 0 { + return input, nil } - return roundedUp / 2, nil + return 1 << (bits.Len64(uint64(input)) - 1), nil } // BlobMinSquareSize returns the minimum square size that can contain shareCount diff --git a/inclusion/blob_share_commitment_rules_test.go b/inclusion/blob_share_commitment_rules_test.go index 81cc115..6229584 100644 --- a/inclusion/blob_share_commitment_rules_test.go +++ b/inclusion/blob_share_commitment_rules_test.go @@ -2,6 +2,7 @@ package inclusion_test import ( "fmt" + "math" "testing" "github.com/celestiaorg/go-square/v2/inclusion" @@ -257,23 +258,29 @@ func TestRoundUpByMultipleOf(t *testing.T) { } } +type roundUpTestCase struct { + input int + want int +} + +var roundUpTestCases = []roundUpTestCase{ + {input: -1, want: 1}, + {input: 0, want: 1}, + {input: 1, want: 1}, + {input: 2, want: 2}, + {input: 4, want: 4}, + {input: 5, want: 8}, + {input: 8, want: 8}, + {input: 11, want: 16}, + {input: 511, want: 512}, + {input: math.MaxInt32 - 1, want: math.MaxInt32}, + {input: math.MaxInt32 + 1, want: math.MaxInt32 * 2}, + {input: math.MaxInt32, want: math.MaxInt32}, + {input: math.MaxInt, want: math.MaxInt}, +} + func TestRoundUpPowerOfTwo(t *testing.T) { - type testCase struct { - input int - want int - } - testCases := []testCase{ - {input: -1, want: 1}, - {input: 0, want: 1}, - {input: 1, want: 1}, - {input: 2, want: 2}, - {input: 4, want: 4}, - {input: 5, want: 8}, - {input: 8, want: 8}, - {input: 11, want: 16}, - {input: 511, want: 512}, - } - for _, tc := range testCases { + for _, tc := range roundUpTestCases { got := inclusion.RoundUpPowerOfTwo(tc.input) assert.Equal(t, tc.want, got) } @@ -326,62 +333,64 @@ func TestBlobMinSquareSize(t *testing.T) { } } +type testCase struct { + shareCount int + want int +} + +var subtreeWidthTestCases = []testCase{ + { + shareCount: 0, + want: 1, + }, + { + shareCount: 1, + want: 1, + }, + { + shareCount: 2, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold + 1, + want: 2, + }, + { + shareCount: defaultSubtreeRootThreshold - 1, + want: 1, + }, + { + shareCount: defaultSubtreeRootThreshold * 2, + want: 2, + }, + { + shareCount: (defaultSubtreeRootThreshold * 2) + 1, + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 3) - 1, + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 4), + want: 4, + }, + { + shareCount: (defaultSubtreeRootThreshold * 5), + want: 8, + }, + { + shareCount: (defaultSubtreeRootThreshold * defaultMaxSquareSize) - 1, + want: 128, + }, +} + func TestSubTreeWidth(t *testing.T) { - type testCase struct { - shareCount int - want int - } - testCases := []testCase{ - { - shareCount: 0, - want: 1, - }, - { - shareCount: 1, - want: 1, - }, - { - shareCount: 2, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold + 1, - want: 2, - }, - { - shareCount: defaultSubtreeRootThreshold - 1, - want: 1, - }, - { - shareCount: defaultSubtreeRootThreshold * 2, - want: 2, - }, - { - shareCount: (defaultSubtreeRootThreshold * 2) + 1, - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 3) - 1, - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 4), - want: 4, - }, - { - shareCount: (defaultSubtreeRootThreshold * 5), - want: 8, - }, - { - shareCount: (defaultSubtreeRootThreshold * defaultMaxSquareSize) - 1, - want: 128, - }, - } - for i, tc := range testCases { + for i, tc := range subtreeWidthTestCases { t.Run(fmt.Sprintf("shareCount %d", tc.shareCount), func(t *testing.T) { got := inclusion.SubTreeWidth(tc.shareCount, defaultSubtreeRootThreshold) assert.Equal(t, tc.want, got, i) @@ -389,21 +398,80 @@ func TestSubTreeWidth(t *testing.T) { } } -func TestRoundDownPowerOfTwo(t *testing.T) { - type testCase struct { - input int - want int +var sink any = nil + +func BenchmarkSubTreeWidth(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range subtreeWidthTestCases { + got := inclusion.SubTreeWidth(tc.shareCount, defaultSubtreeRootThreshold) + assert.Equal(b, tc.want, got) + sink = got + } } - testCases := []testCase{ - {input: 1, want: 1}, - {input: 2, want: 2}, - {input: 4, want: 4}, - {input: 5, want: 4}, - {input: 8, want: 8}, - {input: 11, want: 8}, - {input: 511, want: 256}, + + if sink == nil { + b.Fatal("Benchmark did not run!") } - for _, tc := range testCases { + sink = nil +} + +func BenchmarkRoundDownPowerOfTwo(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range roundDownTestCases { + got, _ := inclusion.RoundDownPowerOfTwo(tc.input) + assert.Equal(b, tc.want, got) + sink = got + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} + +func BenchmarkRoundUpPowerOfTwo(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tc := range roundUpTestCases { + got := inclusion.RoundUpPowerOfTwo(tc.input) + assert.Equal(b, tc.want, got) + sink = got + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} + +type roundDownTestCase struct { + input int + want int +} + +var roundDownTestCases = []roundDownTestCase{ + {input: 1, want: 1}, + {input: 2, want: 2}, + {input: 4, want: 4}, + {input: 5, want: 4}, + {input: 8, want: 8}, + {input: 11, want: 8}, + {input: 511, want: 256}, + {input: math.MaxInt32 - 1, want: math.MaxInt32 / 2}, + {input: math.MaxInt32 + 1, want: math.MaxInt32}, + {input: math.MaxInt32, want: math.MaxInt32}, + {input: math.MaxInt, want: math.MaxInt}, +} + +func TestRoundDownPowerOfTwo(t *testing.T) { + for _, tc := range roundDownTestCases { got, err := inclusion.RoundDownPowerOfTwo(tc.input) require.NoError(t, err) assert.Equal(t, tc.want, got)