Skip to content

Commit

Permalink
Add State provider fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
amsanghi committed Nov 15, 2023
1 parent a8fc94b commit 4b8676d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 74 deletions.
174 changes: 104 additions & 70 deletions staker/state_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"context"
"errors"
"fmt"
"strings"
"sync"

"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"

Expand All @@ -27,26 +27,38 @@ var (
_ l2stateprovider.ExecutionProvider = (*StateManager)(nil)
)

// Defines the ABI encoding structure for submission of prefix proofs to the protocol contracts
var (
b32Arr, _ = abi.NewType("bytes32[]", "", nil)
// ProofArgs for submission to the protocol.
ProofArgs = abi.Arguments{
{Type: b32Arr, Name: "prefixExpansion"},
{Type: b32Arr, Name: "prefixProof"},
}
)

var (
ErrChainCatchingUp = errors.New("chain catching up")
)

type Opt func(*StateManager)
type BoldConfig struct {
Enable bool `koanf:"enable"`
Mode string `koanf:"mode"`
BlockChallengeLeafHeight uint64 `koanf:"block-challenge-leaf-height"`
BigStepLeafHeight uint64 `koanf:"big-step-leaf-height"`
SmallStepLeafHeight uint64 `koanf:"small-step-leaf-height"`
NumBigSteps uint64 `koanf:"num-big-steps"`
ValidatorName string `koanf:"validator-name"`
MachineLeavesCachePath string `koanf:"machine-leaves-cache-path"`
AssertionPostingIntervalSeconds uint64 `koanf:"assertion-posting-interval-seconds"`
AssertionScanningIntervalSeconds uint64 `koanf:"assertion-scanning-interval-seconds"`
AssertionConfirmingIntervalSeconds uint64 `koanf:"assertion-confirming-interval-seconds"`
EdgeTrackerWakeIntervalSeconds uint64 `koanf:"edge-tracker-wake-interval-seconds"`
}

func DisableCache() Opt {
return func(sm *StateManager) {
sm.historyCache = nil
}
var DefaultBoldConfig = BoldConfig{
Enable: false,
Mode: "make-mode",
BlockChallengeLeafHeight: 1 << 5,
BigStepLeafHeight: 1 << 5,
SmallStepLeafHeight: 1 << 7,
NumBigSteps: 5,
ValidatorName: "default-validator",
MachineLeavesCachePath: "/tmp/machine-leaves-cache",
AssertionPostingIntervalSeconds: 30,
AssertionScanningIntervalSeconds: 30,
AssertionConfirmingIntervalSeconds: 60,
EdgeTrackerWakeIntervalSeconds: 1,
}

type StateManager struct {
Expand All @@ -62,7 +74,6 @@ func NewStateManager(
cacheBaseDir string,
challengeLeafHeights []l2stateprovider.Height,
validatorName string,
opts ...Opt,
) (*StateManager, error) {
historyCache := challengecache.New(cacheBaseDir)
sm := &StateManager{
Expand All @@ -71,13 +82,10 @@ func NewStateManager(
challengeLeafHeights: challengeLeafHeights,
validatorName: validatorName,
}
for _, o := range opts {
o(sm)
}
return sm, nil
}

// ExecutionStateMsgCount If the state manager locally has this validated execution state.
// AgreesWithExecutionState If the state manager locally has this validated execution state.
// Returns ErrNoExecutionState if not found, or ErrChainCatchingUp if not yet
// validated / syncing.
func (s *StateManager) AgreesWithExecutionState(ctx context.Context, state *protocol.ExecutionState) error {
Expand Down Expand Up @@ -133,6 +141,9 @@ func (s *StateManager) ExecutionStateAfterBatchCount(ctx context.Context, batchC
batchIndex := batchCount - 1
messageCount, err := s.validator.inboxTracker.GetBatchMessageCount(batchIndex)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, fmt.Errorf("%w: batch count %d", l2stateprovider.ErrChainCatchingUp, batchCount)
}
return nil, err
}
globalState, err := s.findGlobalStateFromMessageCountAndBatch(messageCount, l2stateprovider.Batch(batchIndex))
Expand All @@ -158,76 +169,88 @@ func (s *StateManager) StatesInBatchRange(
fromBatch,
toBatch l2stateprovider.Batch,
) ([]common.Hash, []validator.GoGlobalState, error) {
// Check integrity of the arguments.
if fromBatch > toBatch {
return nil, nil, fmt.Errorf("from batch %v is greater than to batch %v", fromBatch, toBatch)
// Check the integrity of the arguments.
if fromBatch >= toBatch {
return nil, nil, fmt.Errorf("from batch %v cannot be greater than or equal to batch %v", fromBatch, toBatch)
}
if fromHeight > toHeight {
return nil, nil, fmt.Errorf("from height %v is greater than to height %v", fromHeight, toHeight)
return nil, nil, fmt.Errorf("from height %v cannot be greater than to height %v", fromHeight, toHeight)
}
// Compute the total desired hashes from this request.
totalDesiredHashes := (toHeight - fromHeight) + 1

// The last message's batch count.
// Get the fromBatch's message count.
prevBatchMsgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(fromBatch) - 1)
if err != nil {
return nil, nil, err
}
gs, err := s.findGlobalStateFromMessageCountAndBatch(prevBatchMsgCount, fromBatch-1)
executionResult, err := s.validator.streamer.ResultAtCount(prevBatchMsgCount)
if err != nil {
return nil, nil, err
}
if gs.PosInBatch == 0 {
return nil, nil, errors.New("final state of batch cannot be at position zero")
}
// The start state root of our history commitment starts at `batch: fromBatch, pos: 0` using the state
// from the last batch.
gs.Batch += 1
gs.PosInBatch = 0
stateRoots := []common.Hash{
crypto.Keccak256Hash([]byte("Machine finished:"), gs.Hash().Bytes()),
}
globalStates := []validator.GoGlobalState{gs}

// Check if there are enough messages in the range to satisfy our request.
totalDesiredHashes := (toHeight - fromHeight) + 1

// We can return early if all we want is one hash.
if totalDesiredHashes == 1 && fromHeight == 0 && toHeight == 0 {
return stateRoots, globalStates, nil
startState := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(fromBatch),
PosInBatch: 0,
}
machineHashes := []common.Hash{machineHash(startState)}
states := []validator.GoGlobalState{startState}

for batch := fromBatch; batch < toBatch; batch++ {
msgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(batch))
batchMessageCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(batch))
if err != nil {
return nil, nil, err
}
var lastGlobalState validator.GoGlobalState
messagesInBatch := batchMessageCount - prevBatchMsgCount

msgsInBatch := msgCount - prevBatchMsgCount
for i := uint64(1); i <= uint64(msgsInBatch); i++ {
// Obtain the states for each message in the batch.
for i := uint64(0); i < uint64(messagesInBatch); i++ {
msgIndex := uint64(prevBatchMsgCount) + i
gs, err := s.findGlobalStateFromMessageCountAndBatch(arbutil.MessageIndex(msgIndex), batch)
messageCount := msgIndex + 1
executionResult, err := s.validator.streamer.ResultAtCount(arbutil.MessageIndex(messageCount))
if err != nil {
return nil, nil, err
}
globalStates = append(globalStates, gs)
stateRoots = append(stateRoots,
crypto.Keccak256Hash([]byte("Machine finished:"), gs.Hash().Bytes()),
)
lastGlobalState = gs
// If the position in batch is equal to the number of messages in the batch,
// we do not include this state, instead, we break and include the state
// that fully consumes the batch.
if i+1 == uint64(messagesInBatch) {
break
}
state := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(batch),
PosInBatch: i + 1,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
}
prevBatchMsgCount = msgCount
lastGlobalState.Batch += 1
lastGlobalState.PosInBatch = 0
stateRoots = append(stateRoots,
crypto.Keccak256Hash([]byte("Machine finished:"), lastGlobalState.Hash().Bytes()),
)
globalStates = append(globalStates, lastGlobalState)
}

for uint64(len(stateRoots)) < uint64(totalDesiredHashes) {
stateRoots = append(stateRoots, stateRoots[len(stateRoots)-1])
// Fully consume the batch.
executionResult, err := s.validator.streamer.ResultAtCount(batchMessageCount)
if err != nil {
return nil, nil, err
}
state := validator.GoGlobalState{
BlockHash: executionResult.BlockHash,
SendRoot: executionResult.SendRoot,
Batch: uint64(batch) + 1,
PosInBatch: 0,
}
states = append(states, state)
machineHashes = append(machineHashes, machineHash(state))
prevBatchMsgCount = batchMessageCount
}
return stateRoots[fromHeight : toHeight+1], globalStates[fromHeight : toHeight+1], nil
for uint64(len(machineHashes)) < uint64(totalDesiredHashes) {
machineHashes = append(machineHashes, machineHashes[len(machineHashes)-1])
}
return machineHashes[fromHeight : toHeight+1], states, nil
}

func machineHash(gs validator.GoGlobalState) common.Hash {
return crypto.Keccak256Hash([]byte("Machine finished:"), gs.Hash().Bytes())
}

func (s *StateManager) findGlobalStateFromMessageCountAndBatch(count arbutil.MessageIndex, batchIndex l2stateprovider.Batch) (validator.GoGlobalState, error) {
Expand Down Expand Up @@ -284,9 +307,14 @@ func (s *StateManager) CollectMachineHashes(
) ([]common.Hash, error) {
s.Lock()
defer s.Unlock()
prevBatchMsgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(cfg.FromBatch - 1))
if err != nil {
return nil, fmt.Errorf("could not get batch message count at %d: %w", cfg.FromBatch, err)
}
messageNum := prevBatchMsgCount + arbutil.MessageIndex(cfg.BlockChallengeHeight)
cacheKey := &challengecache.Key{
WavmModuleRoot: cfg.WasmModuleRoot,
MessageHeight: protocol.Height(cfg.MessageNumber),
MessageHeight: protocol.Height(messageNum),
StepHeights: cfg.StepHeights,
}
if s.historyCache != nil {
Expand All @@ -298,7 +326,7 @@ func (s *StateManager) CollectMachineHashes(
return nil, err
}
}
entry, err := s.validator.CreateReadyValidationEntry(ctx, arbutil.MessageIndex(cfg.MessageNumber))
entry, err := s.validator.CreateReadyValidationEntry(ctx, messageNum)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -330,10 +358,16 @@ func (s *StateManager) CollectMachineHashes(
func (s *StateManager) CollectProof(
ctx context.Context,
wasmModuleRoot common.Hash,
messageNumber l2stateprovider.Height,
fromBatch l2stateprovider.Batch,
blockChallengeHeight l2stateprovider.Height,
machineIndex l2stateprovider.OpcodeIndex,
) ([]byte, error) {
entry, err := s.validator.CreateReadyValidationEntry(ctx, arbutil.MessageIndex(messageNumber))
prevBatchMsgCount, err := s.validator.inboxTracker.GetBatchMessageCount(uint64(fromBatch) - 1)
if err != nil {
return nil, err
}
messageNum := prevBatchMsgCount + arbutil.MessageIndex(blockChallengeHeight)
entry, err := s.validator.CreateReadyValidationEntry(ctx, messageNum)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion system_tests/assertion_on_large_number_of_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestAssertionOnLargeNumberOfBatch(t *testing.T) {
manager, err := staker.NewStateManager(stateless, t.TempDir(), nil)
Require(t, err)

poster := assertions.NewPoster(
poster := assertions.NewManager(
assertionChain,
manager,
"test",
Expand Down
4 changes: 2 additions & 2 deletions system_tests/bold_challenge_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestBoldProtocol(t *testing.T) {
)
Require(t, err)

poster, err := assertions.NewPoster(
poster, err := assertions.NewManager(
assertionChain,
stateManager,
"good",
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestBoldProtocol(t *testing.T) {
)
Require(t, err)

posterB, err := assertions.NewPoster(
posterB, err := assertions.NewManager(
chainB,
stateManagerB,
"evil",
Expand Down
2 changes: 1 addition & 1 deletion validator/server_arb/execution_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (e *executionRun) GetLeavesWithStepSize(machineStartIndex, stepSize, numDes
for uint64(len(stateRoots)) < numDesiredLeaves {
stateRoots = append(stateRoots, stateRoots[len(stateRoots)-1])
}
return stateRoots, nil
return stateRoots[:numDesiredLeaves], nil
})
}

Expand Down

0 comments on commit 4b8676d

Please sign in to comment.