Use Mmap to handle large slices of leaves #1931

Merged
merged 4 commits on Nov 28, 2023
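For context, this diff swaps plain `[]common.Hash` slices for an `mmap.Mmap` type from `github.com/OffchainLabs/bold/mmap`, so large runs of history-commitment leaves live in memory-mapped storage rather than on the Go heap and can be released explicitly with `Free()`. The sketch below is not the BOLD implementation; it is a slice-backed stand-in, consistent with how the diff uses the type (`NewMmap`, `Get`, `Set`, `Length`, `SubMmap`, `Free`, plus `len` and the `mmap.Mmap{}` literal), included only to make the call sites easier to follow.

```go
// Illustrative stand-in only: the real type lives in github.com/OffchainLabs/bold/mmap
// and is backed by a memory-mapped region. This version is backed by a plain slice
// but exposes the same method set the diff relies on.
package mmap

import "github.com/ethereum/go-ethereum/common"

// Mmap behaves like a fixed-length sequence of 32-byte hashes.
type Mmap []common.Hash

// NewMmap allocates room for length hashes.
func NewMmap(length int) (Mmap, error) {
	return make(Mmap, length), nil
}

// Get returns the hash stored at index i.
func (m Mmap) Get(i int) common.Hash { return m[i] }

// Set stores h at index i.
func (m Mmap) Set(i int, h common.Hash) { m[i] = h }

// Length reports how many hashes the mapping holds.
func (m Mmap) Length() int { return len(m) }

// SubMmap returns the window of hashes in [start, end).
func (m Mmap) SubMmap(start, end int) Mmap { return m[start:end] }

// Free releases the underlying mapping; a no-op for this slice-backed stand-in.
func (m Mmap) Free() {}
```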
38 changes: 22 additions & 16 deletions staker/challenge-cache/cache.go
@@ -40,6 +40,8 @@ import (

protocol "github.com/OffchainLabs/bold/chain-abstraction"
l2stateprovider "github.com/OffchainLabs/bold/layer2-state-provider"
"github.com/OffchainLabs/bold/mmap"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
)
@@ -58,8 +60,8 @@ var (

// HistoryCommitmentCacher can retrieve history commitment state roots given lookup keys.
type HistoryCommitmentCacher interface {
Get(lookup *Key, numToRead uint64) ([]common.Hash, error)
Put(lookup *Key, stateRoots []common.Hash) error
Get(lookup *Key, numToRead uint64) (mmap.Mmap, error)
Put(lookup *Key, stateRoots mmap.Mmap) error
}
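Since both interface methods now traffic in `mmap.Mmap`, callers of `Get` receive a mapping rather than a heap slice. The helper below is a hedged sketch written as if it lived inside this challenge-cache package; the function name and the convention that the caller frees the mapping are assumptions, not part of the diff, and the copy back to a heap slice is only there to show the access pattern.

```go
// copyLeaves is a hypothetical helper showing how a caller might consume the
// updated interface: fetch numToRead leaves, copy them out, and release the
// mapping before returning.
func copyLeaves(c HistoryCommitmentCacher, lookup *Key, numToRead uint64) ([]common.Hash, error) {
	roots, err := c.Get(lookup, numToRead)
	if err != nil {
		return nil, err
	}
	defer roots.Free() // assumption: callers release mappings they receive, as the tests do
	out := make([]common.Hash, roots.Length())
	for i := 0; i < roots.Length(); i++ {
		out[i] = roots.Get(i)
	}
	return out, nil
}
```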

// Cache for history commitments on disk.
@@ -121,7 +123,7 @@ type Key struct {
func (c *Cache) Get(
lookup *Key,
numToRead uint64,
) ([]common.Hash, error) {
) (mmap.Mmap, error) {
fName, err := determineFilePath(c.baseDir, lookup)
if err != nil {
return nil, err
@@ -147,7 +149,7 @@ func (c *Cache) Get(
// State roots are saved as files in a directory hierarchy for the cache.
// This function first creates a temporary file, writes the state roots to it, and then renames the file
// to the final directory to ensure atomic writes.
func (c *Cache) Put(lookup *Key, stateRoots []common.Hash) error {
func (c *Cache) Put(lookup *Key, stateRoots mmap.Mmap) error {
// We should error if trying to put 0 state roots to disk.
if len(stateRoots) == 0 {
return ErrNoStateRoots
@@ -189,11 +191,15 @@ func (c *Cache) Put(lookup *Key, stateRoots []common.Hash) error {
}

// Reads 32 bytes at a time from a reader up to a specified height. If none, then read all.
func readStateRoots(r io.Reader, numToRead uint64) ([]common.Hash, error) {
func readStateRoots(r io.Reader, numToRead uint64) (mmap.Mmap, error) {
br := bufio.NewReader(r)
stateRoots := make([]common.Hash, 0)
stateRootsMmap, err := mmap.NewMmap(int(numToRead))
if err != nil {
return nil, err
}
buf := make([]byte, 0, 32)
for totalRead := uint64(0); totalRead < numToRead; totalRead++ {
var totalRead uint64
for totalRead = uint64(0); totalRead < numToRead; totalRead++ {
n, err := br.Read(buf[:cap(buf)])
if err != nil {
// If we try to read but reach EOF, we break out of the loop.
@@ -206,30 +212,30 @@ func readStateRoots(r io.Reader, numToRead uint64) ([]common.Hash, error) {
if n != 32 {
return nil, fmt.Errorf("expected to read 32 bytes, got %d bytes", n)
}
stateRoots = append(stateRoots, common.BytesToHash(buf))
stateRootsMmap.Set(int(totalRead), common.BytesToHash(buf))
}
if protocol.Height(numToRead) > protocol.Height(len(stateRoots)) {
if protocol.Height(numToRead) > protocol.Height(totalRead) {
return nil, fmt.Errorf(
"wanted to read %d roots, but only read %d state roots",
numToRead,
len(stateRoots),
totalRead,
)
}
return stateRoots, nil
return stateRootsMmap, nil
}

func writeStateRoots(w io.Writer, stateRoots []common.Hash) error {
for i, rt := range stateRoots {
n, err := w.Write(rt[:])
func writeStateRoots(w io.Writer, stateRoots mmap.Mmap) error {
for i := 0; i < stateRoots.Length(); i++ {
n, err := w.Write(stateRoots.Get(i).Bytes())
if err != nil {
return err
}
if n != len(rt) {
if n != len(stateRoots.Get(i)) {
return fmt.Errorf(
"for state root %d, wrote %d bytes, expected to write %d bytes",
i,
n,
len(rt),
len(stateRoots.Get(i)),
)
}
}
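The read/write helpers above serialize each leaf as a raw 32-byte chunk, so a written mapping can be read back byte for byte. Below is a hedged round-trip sketch, written as if it lived in a `_test.go` file inside this package (`writeStateRoots` and `readStateRoots` are unexported; the test name is hypothetical).

```go
package challengecache

import (
	"bytes"
	"testing"

	"github.com/OffchainLabs/bold/mmap"
	"github.com/ethereum/go-ethereum/common"
)

// TestRoundTripSketch writes two leaves into an in-memory buffer and reads
// them back, checking that the mmap-backed round trip is lossless.
func TestRoundTripSketch(t *testing.T) {
	roots, err := mmap.NewMmap(2)
	if err != nil {
		t.Fatal(err)
	}
	defer roots.Free()
	roots.Set(0, common.BytesToHash([]byte("foo")))
	roots.Set(1, common.BytesToHash([]byte("bar")))

	var buf bytes.Buffer
	if err := writeStateRoots(&buf, roots); err != nil { // 32 bytes per leaf
		t.Fatal(err)
	}
	readBack, err := readStateRoots(&buf, 2) // ask for exactly two leaves
	if err != nil {
		t.Fatal(err)
	}
	defer readBack.Free()
	for i := 0; i < readBack.Length(); i++ {
		if readBack.Get(i) != roots.Get(i) {
			t.Fatalf("leaf %d mismatch", i)
		}
	}
}
```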
51 changes: 33 additions & 18 deletions staker/challenge-cache/cache_test.go
@@ -12,6 +12,8 @@ import (
"testing"

l2stateprovider "github.com/OffchainLabs/bold/layer2-state-provider"
"github.com/OffchainLabs/bold/mmap"

"github.com/ethereum/go-ethereum/common"
)

@@ -43,15 +45,18 @@ func TestCache(t *testing.T) {
}
})
t.Run("Putting empty root fails", func(t *testing.T) {
if err := cache.Put(key, []common.Hash{}); !errors.Is(err, ErrNoStateRoots) {
if err := cache.Put(key, mmap.Mmap{}); !errors.Is(err, ErrNoStateRoots) {
t.Fatalf("Unexpected error: %v", err)
}
})
want := []common.Hash{
common.BytesToHash([]byte("foo")),
common.BytesToHash([]byte("bar")),
common.BytesToHash([]byte("baz")),
want, err := mmap.NewMmap(3)
if err != nil {
t.Fatal(err)
}
defer want.Free()
want.Set(0, common.BytesToHash([]byte("foo")))
want.Set(1, common.BytesToHash([]byte("bar")))
want.Set(2, common.BytesToHash([]byte("baz")))
err = cache.Put(key, want)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -92,7 +97,7 @@ func TestReadWriteStateRoots(t *testing.T) {
if len(roots) == 0 {
t.Fatal("Got no roots")
}
if roots[0] != want {
if roots.Get(0) != want {
t.Fatalf("Wrong root. Expected %#x, got %#x", want, roots[0])
}
})
@@ -108,24 +113,30 @@ func TestReadWriteStateRoots(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(roots) != 2 {
if roots.Length() != 2 {
t.Fatalf("Expected two roots, got %d", len(roots))
}
if roots[0] != foo {
if roots.Get(0) != foo {
t.Fatalf("Wrong root. Expected %#x, got %#x", foo, roots[0])
}
if roots[1] != bar {
if roots.Get(1) != bar {
t.Fatalf("Wrong root. Expected %#x, got %#x", bar, roots[1])
}
})
t.Run("Fails to write enough data to writer", func(t *testing.T) {
m := &mockWriter{wantErr: true}
err := writeStateRoots(m, []common.Hash{common.BytesToHash([]byte("foo"))})
stateRoots, err := mmap.NewMmap(1)
if err != nil {
t.Fatal(err)
}
defer stateRoots.Free()
stateRoots.Set(0, common.BytesToHash([]byte("foo")))
err = writeStateRoots(m, stateRoots)
if err == nil {
t.Fatal("Wanted error")
}
m = &mockWriter{wantErr: false, numWritten: 16}
err = writeStateRoots(m, []common.Hash{common.BytesToHash([]byte("foo"))})
err = writeStateRoots(m, stateRoots)
if err == nil {
t.Fatal("Wanted error")
}
@@ -224,11 +235,11 @@ func Test_readStateRoots(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(want) != len(got) {
if len(want) != got.Length() {
t.Fatal("Wrong number of roots")
}
for i, rt := range got {
if rt != want[i] {
for i := 0; i < got.Length(); i++ {
if got.Get(i) != want[i] {
t.Fatal("Wrong root")
}
}
@@ -303,11 +314,15 @@ func BenchmarkCache_Read_32Mb(b *testing.B) {
StepHeights: []l2stateprovider.Height{l2stateprovider.Height(0)},
}
numRoots := 1 << 20
roots := make([]common.Hash, numRoots)
for i := range roots {
roots[i] = common.BytesToHash([]byte(fmt.Sprintf("%d", i)))
rootsMmap, err := mmap.NewMmap(numRoots)
if err != nil {
b.Fatal(err)
}
defer rootsMmap.Free()
for i := 0; i < numRoots; i++ {
rootsMmap.Set(i, common.BytesToHash([]byte(fmt.Sprintf("%d", i))))
}
if err := cache.Put(key, roots); err != nil {
if err := cache.Put(key, rootsMmap); err != nil {
b.Fatal(err)
}
b.StartTimer()
50 changes: 30 additions & 20 deletions staker/state_provider.go
@@ -19,6 +19,7 @@ import (
"github.com/OffchainLabs/bold/challenge-manager/types"
"github.com/OffchainLabs/bold/containers/option"
l2stateprovider "github.com/OffchainLabs/bold/layer2-state-provider"
"github.com/OffchainLabs/bold/mmap"

"github.com/offchainlabs/nitro/arbutil"
challengecache "github.com/offchainlabs/nitro/staker/challenge-cache"
@@ -186,39 +187,45 @@ func (s *StateManager) StatesInBatchRange(
toHeight l2stateprovider.Height,
fromBatch,
toBatch l2stateprovider.Batch,
) ([]common.Hash, []validator.GoGlobalState, error) {
) (mmap.Mmap, error) {
// Check the integrity of the arguments.
if fromBatch >= toBatch {
return nil, nil, fmt.Errorf("from batch %v cannot be greater than or equal to batch %v", fromBatch, toBatch)
return nil, fmt.Errorf("from batch %v cannot be greater than or equal to batch %v", fromBatch, toBatch)
}
if fromHeight > toHeight {
return nil, nil, fmt.Errorf("from height %v cannot be greater than to height %v", fromHeight, toHeight)
return nil, fmt.Errorf("from height %v cannot be greater than to height %v", fromHeight, toHeight)
}
// Compute the total desired hashes from this request.
totalDesiredHashes := (toHeight - fromHeight) + 1

// Get the fromBatch's message count.
prevBatchMsgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(fromBatch) - 1)
if err != nil {
return nil, nil, err
return nil, err
}
executionResult, err := s.validator.streamer.ResultAtCount(prevBatchMsgCount)
if err != nil {
return nil, nil, err
return nil, err
}
startState := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(fromBatch),
PosInBatch: 0,
}
machineHashes := []common.Hash{machineHash(startState)}
states := []validator.GoGlobalState{startState}
machineHashesMmap, err := mmap.NewMmap(int(totalDesiredHashes))
if err != nil {
return nil, err
}
numStateRoots := 0
machineHashesMmap.Set(numStateRoots, machineHash(startState))
numStateRoots++

for batch := fromBatch; batch < toBatch; batch++ {
batchMessageCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(batch))
if err != nil {
return nil, nil, err
machineHashesMmap.Free()
return nil, err
}
messagesInBatch := batchMessageCount - prevBatchMsgCount

@@ -228,7 +235,8 @@ func (s *StateManager) StatesInBatchRange(
messageCount := msgIndex + 1
executionResult, err := s.validator.streamer.ResultAtCount(arbutil.MessageIndex(messageCount))
if err != nil {
return nil, nil, err
machineHashesMmap.Free()
return nil, err
}
// If the position in batch is equal to the number of messages in the batch,
// we do not include this state, instead, we break and include the state
@@ -242,29 +250,31 @@ func (s *StateManager) StatesInBatchRange(
Batch: uint64(batch),
PosInBatch: i + 1,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
machineHashesMmap.Set(numStateRoots, machineHash(state))
numStateRoots++
}

// Fully consume the batch.
executionResult, err := s.validator.streamer.ResultAtCount(batchMessageCount)
if err != nil {
return nil, nil, err
machineHashesMmap.Free()
return nil, err
}
state := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(batch) + 1,
PosInBatch: 0,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
machineHashesMmap.Set(numStateRoots, machineHash(state))
numStateRoots++
prevBatchMsgCount = batchMessageCount
}
for uint64(len(machineHashes)) < uint64(totalDesiredHashes) {
machineHashes = append(machineHashes, machineHashes[len(machineHashes)-1])
lastMachineHashes := machineHashesMmap.Get(numStateRoots - 1)
for i := numStateRoots; i < int(totalDesiredHashes); i++ {
machineHashesMmap.Set(i, lastMachineHashes)
}
return machineHashes[fromHeight : toHeight+1], states, nil
return machineHashesMmap.SubMmap(int(fromHeight), int(toHeight+1)), nil
}
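The tail of `StatesInBatchRange` pads the mapping by repeating the last computed machine hash until it holds `totalDesiredHashes` entries, then returns only the `[fromHeight, toHeight]` window via `SubMmap`. A minimal sketch of that pattern follows; the helper name and standalone signature are illustrative, not part of the diff.

```go
// padAndWindow mirrors the padding-and-windowing step above: repeat the last
// written hash until the mapping holds total entries, then return the
// half-open window [from, to+1).
func padAndWindow(hashes mmap.Mmap, written, total, from, to int) mmap.Mmap {
	last := hashes.Get(written - 1)
	for i := written; i < total; i++ {
		hashes.Set(i, last)
	}
	return hashes.SubMmap(from, to+1)
}
```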

func machineHash(gs validator.GoGlobalState) common.Hash {
@@ -304,15 +314,15 @@ func (s *StateManager) L2MessageStatesUpTo(
toHeight option.Option[l2stateprovider.Height],
fromBatch,
toBatch l2stateprovider.Batch,
) ([]common.Hash, error) {
) (mmap.Mmap, error) {
var to l2stateprovider.Height
if !toHeight.IsNone() {
to = toHeight.Unwrap()
} else {
blockChallengeLeafHeight := s.challengeLeafHeights[0]
to = blockChallengeLeafHeight
}
items, _, err := s.StatesInBatchRange(fromHeight, to, fromBatch, toBatch)
items, err := s.StatesInBatchRange(fromHeight, to, fromBatch, toBatch)
if err != nil {
return nil, err
}
@@ -322,7 +332,7 @@
// CollectMachineHashes Collects a list of machine hashes at a message number based on some configuration parameters.
func (s *StateManager) CollectMachineHashes(
ctx context.Context, cfg *l2stateprovider.HashCollectorConfig,
) ([]common.Hash, error) {
) (mmap.Mmap, error) {
s.Lock()
defer s.Unlock()
prevBatchMsgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(cfg.FromBatch - 1))
13 changes: 1 addition & 12 deletions system_tests/state_provider_test.go
@@ -157,23 +157,12 @@
toBatch := l2stateprovider.Batch(3)
fromHeight := l2stateprovider.Height(0)
toHeight := l2stateprovider.Height(14)
stateRoots, states, err := stateManager.StatesInBatchRange(fromHeight, toHeight, fromBatch, toBatch)
stateRoots, err := stateManager.StatesInBatchRange(fromHeight, toHeight, fromBatch, toBatch)
Require(t, err)

if len(stateRoots) != 15 {
Fatal(t, "wrong number of state roots")
}
if len(states) == 0 {
Fatal(t, "no states returned")
}
firstState := states[0]
if firstState.Batch != 1 && firstState.PosInBatch != 0 {
Fatal(t, "wrong first state")
}
lastState := states[len(states)-1]
if lastState.Batch != 1 && lastState.PosInBatch != 0 {
Fatal(t, "wrong last state")
}
})
t.Run("AgreesWithExecutionState", func(t *testing.T) {
// Non-zero position in batch should fail.
@@ -314,7 +303,7 @@
l2stateprovider.Height(smallStepChallengeLeafHeight),
},
"good",
staker.DisableCache(),

Check failure on line 306 in system_tests/state_provider_test.go (GitHub Actions / Go Tests (challenge)): undefined: staker.DisableCache
)
Require(t, err)
return l2node, l1info, l2info, l1stack, l1client, stateManager