From 5aa9b466eff38c63a60eacb00c7e7f0d8bf9706b Mon Sep 17 00:00:00 2001 From: Krish Date: Mon, 9 Sep 2024 10:53:29 +0800 Subject: [PATCH] fix ut --- core/blockchain.go | 104 +++++++++++++++++++------------------- miner/fix_manager.go | 51 ++++++++++++------- miner/miner_test.go | 4 ++ miner/payload_building.go | 40 +++++++++++---- miner/worker_test.go | 6 ++- 5 files changed, 122 insertions(+), 83 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index a40364168d..a4d6960f1d 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -2204,8 +2204,8 @@ func (bc *BlockChain) insertSideChain(block *types.Block, it *insertIterator) (i return 0, nil } -func (bc *BlockChain) RecoverAncestors(block *types.Block) (common.Hash, error) { - return bc.recoverAncestorsWithSethead(block) +func (bc *BlockChain) RecoverStateAndSetHead(block *types.Block) (common.Hash, error) { + return bc.recoverStateAndSetHead(block) } // recoverAncestors finds the closest ancestor with available state and re-execute @@ -2259,57 +2259,6 @@ func (bc *BlockChain) recoverAncestors(block *types.Block) (common.Hash, error) return block.Hash(), nil } -// recoverAncestors finds the closest ancestor with available state and re-execute -// all the ancestor blocks since that. -// recoverAncestors is only used post-merge. -// We return the hash of the latest block that we could correctly validate. -func (bc *BlockChain) recoverAncestorsWithSethead(block *types.Block) (common.Hash, error) { - // Gather all the sidechain hashes (full blocks may be memory heavy) - var ( - hashes []common.Hash - numbers []uint64 - parent = block - ) - for parent != nil && !bc.HasState(parent.Root()) { - if bc.stateRecoverable(parent.Root()) { - if err := bc.triedb.Recover(parent.Root()); err != nil { - return common.Hash{}, err - } - break - } - hashes = append(hashes, parent.Hash()) - numbers = append(numbers, parent.NumberU64()) - parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1) - - // If the chain is terminating, stop iteration - if bc.insertStopped() { - log.Debug("Abort during blocks iteration") - return common.Hash{}, errInsertionInterrupted - } - } - if parent == nil { - return common.Hash{}, errors.New("missing parent") - } - // Import all the pruned blocks to make the state available - for i := len(hashes) - 1; i >= 0; i-- { - // If the chain is terminating, stop processing blocks - if bc.insertStopped() { - log.Debug("Abort during blocks processing") - return common.Hash{}, errInsertionInterrupted - } - var b *types.Block - if i == 0 { - b = block - } else { - b = bc.GetBlock(hashes[i], numbers[i]) - } - if _, err := bc.insertChain(types.Blocks{b}, true); err != nil { - return b.ParentHash(), err - } - } - return block.Hash(), nil -} - // collectLogs collects the logs that were generated or removed during // the processing of a block. These logs are later announced as deleted or reborn. func (bc *BlockChain) collectLogs(b *types.Block, removed bool) []*types.Log { @@ -2764,6 +2713,55 @@ func (bc *BlockChain) InsertHeaderChain(chain []*types.Header) (int, error) { return 0, err } +// recoverStateAndSetHead attempts to recover the state of the blockchain by re-importing +// missing blocks and advancing the chain head. It ensures the state is available +// for the given block and its ancestors before updating the head. 
+func (bc *BlockChain) recoverStateAndSetHead(block *types.Block) (common.Hash, error) {
+	var (
+		hashes  []common.Hash
+		numbers []uint64
+		parent  = block
+	)
+	for parent != nil && !bc.HasState(parent.Root()) {
+		if bc.stateRecoverable(parent.Root()) {
+			if err := bc.triedb.Recover(parent.Root()); err != nil {
+				return common.Hash{}, err
+			}
+			break
+		}
+		hashes = append(hashes, parent.Hash())
+		numbers = append(numbers, parent.NumberU64())
+		parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1)
+
+		// If the chain is terminating, stop iteration
+		if bc.insertStopped() {
+			log.Debug("Abort during blocks iteration")
+			return common.Hash{}, errInsertionInterrupted
+		}
+	}
+	if parent == nil {
+		return common.Hash{}, errors.New("missing parent")
+	}
+	// Import all the pruned blocks to make the state available
+	for i := len(hashes) - 1; i >= 0; i-- {
+		// If the chain is terminating, stop processing blocks
+		if bc.insertStopped() {
+			log.Debug("Abort during blocks processing")
+			return common.Hash{}, errInsertionInterrupted
+		}
+		var b *types.Block
+		if i == 0 {
+			b = block
+		} else {
+			b = bc.GetBlock(hashes[i], numbers[i])
+		}
+		if _, err := bc.insertChain(types.Blocks{b}, true); err != nil {
+			return b.ParentHash(), err
+		}
+	}
+	return block.Hash(), nil
+}
+
 // SetBlockValidatorAndProcessorForTesting sets the current validator and processor.
 // This method can be used to force an invalid blockchain to be verified for tests.
 // This method is unsafe and should only be used before block import starts.
diff --git a/miner/fix_manager.go b/miner/fix_manager.go
index dbcd237167..933a7c040e 100644
--- a/miner/fix_manager.go
+++ b/miner/fix_manager.go
@@ -21,6 +21,12 @@ type FixManager struct {
 
 }
 
+// FixResult holds the result of the fix operation
+type FixResult struct {
+	Success bool
+	Err     error
+}
+
 // NewFixManager initializes a FixManager with required dependencies
 func NewFixManager(downloader *downloader.Downloader) *FixManager {
 	return &FixManager{
@@ -35,8 +41,8 @@ func (fm *FixManager) StartFix(worker *worker, id engine.PayloadID, parentHash c
 	if !fm.isFixInProgress {
 		fm.isFixInProgress = true
 
-		fixChan := make(chan struct{})
-		fm.fixChannels.Store(id, fixChan)
+		resultChan := make(chan FixResult, 1) // Channel to capture fix result (success or error)
+		fm.fixChannels.Store(id, resultChan)
 
 		go func() {
 			defer func() {
@@ -44,12 +50,12 @@ func (fm *FixManager) StartFix(worker *worker, id engine.PayloadID, parentHash c
 				fm.isFixInProgress = false
 				fm.mutex.Unlock()
 
-				// Notify listeners that the fix is complete
 				if ch, ok := fm.fixChannels.Load(id); ok {
-					close(ch.(chan struct{}))
+					resultChan := ch.(chan FixResult)
+					close(resultChan)
 				}
 			}()
-			worker.fix(parentHash) // Execute the fix logic
+			worker.fix(parentHash, resultChan) // Run the fix logic and report the result
 		}()
 	}
 }
@@ -61,23 +67,34 @@ func (fm *FixManager) StartFix(worker *worker, id engine.PayloadID, parentHash c
 func (fm *FixManager) ListenFixCompletion(worker *worker, id engine.PayloadID, payload *Payload, args *BuildPayloadArgs) {
 	ch, exists := fm.fixChannels.Load(id)
 	if !exists {
-		log.Info("payload is not fixing or has been completed")
+		log.Info("Payload is not being fixed or the fix has already completed")
 		return
 	}
 
 	// Check if a listener goroutine has already been started
 	if _, listenerExists := fm.listenerStarted.LoadOrStore(id, true); listenerExists {
 		log.Info("Listener already started for payload", "payload", id)
-		return // If listener goroutine already exists, return immediately
+		return
 	}
 
 	go func() {
-		log.Info("start waiting")
-		<-ch.(chan struct{}) // Wait for the fix to complete
-		log.Info("Fix completed, retrying payload update", "id", id)
-		worker.retryPayloadUpdate(args, payload)
-		fm.fixChannels.Delete(id)     // Remove the id from fixChannels
-		fm.listenerStarted.Delete(id) // Remove the listener flag for this id
+		log.Info("Start waiting for fix completion")
+		result := <-ch.(chan FixResult) // Wait for the fix result
+
+		// Check the result and decide whether to retry the payload update
+		if result.Success {
+			if err := worker.retryPayloadUpdate(args, payload); err != nil {
+				log.Error("Failed to retry payload update after fix", "id", id, "err", err)
+			} else {
+				log.Info("Payload update after fix succeeded", "id", id)
+			}
+		} else {
+			log.Error("Fix failed, skipping payload update", "id", id, "err", result.Err)
+		}
+
+		// Clean up the fix state
+		fm.fixChannels.Delete(id)
+		fm.listenerStarted.Delete(id)
 	}()
 }
 
@@ -90,13 +107,13 @@ func (fm *FixManager) RecoverFromLocal(w *worker, blockHash common.Hash) error {
 		return fmt.Errorf("block not found in local chain")
 	}
 
-	log.Info("Fixing data for block", "blocknumber", block.NumberU64())
-	latestValid, err := w.chain.RecoverAncestors(block)
+	log.Info("Fixing data for block", "blockNumber", block.NumberU64())
+	latestValid, err := w.chain.RecoverStateAndSetHead(block)
 	if err != nil {
-		return fmt.Errorf("failed to recover ancestors: %v", err)
+		return fmt.Errorf("failed to recover state: %v", err)
 	}
 
-	log.Info("Recovered ancestors up to block", "latestValid", latestValid)
+	log.Info("Recovered state up to block", "latestValid", latestValid)
 	return nil
 }
 
diff --git a/miner/miner_test.go b/miner/miner_test.go
index 411d6026ce..24c626d12b 100644
--- a/miner/miner_test.go
+++ b/miner/miner_test.go
@@ -59,6 +59,10 @@ func (m *mockBackend) TxPool() *txpool.TxPool {
 	return m.txPool
 }
 
+func (m *mockBackend) Downloader() *downloader.Downloader {
+	return nil
+}
+
 func (m *mockBackend) StateAtBlock(block *types.Block, reexec uint64, base *state.StateDB, checkLive bool, preferDisk bool) (statedb *state.StateDB, err error) {
 	return nil, errors.New("not supported")
 }
diff --git a/miner/payload_building.go b/miner/payload_building.go
index 7aff5e19c5..22d66117a3 100644
--- a/miner/payload_building.go
+++ b/miner/payload_building.go
@@ -20,6 +20,7 @@ import (
 	"crypto/sha256"
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"math/big"
 	"strings"
 	"sync"
@@ -276,22 +277,32 @@ func (payload *Payload) stopBuilding() {
 // missing the block), it attempts to retrieve the block header from peers and triggers
 //
 // blockHash: The hash of the latest block that needs to be recovered and fixed.
-func (w *worker) fix(blockHash common.Hash) { +func (w *worker) fix(blockHash common.Hash, resultChan chan FixResult) { log.Info("Fix operation started") + // Try to recover from local data err := w.fixManager.RecoverFromLocal(w, blockHash) if err != nil { - log.Warn("Local recovery failed, trying to recover from peers", "err", err) - - err = w.fixManager.RecoverFromPeer(blockHash) - if err != nil { - log.Error("Failed to recover from peers", "err", err) + // Only proceed to peer recovery if the error is "block not found in local chain" + if strings.Contains(err.Error(), "block not found") { + log.Warn("Local recovery failed, trying to recover from peers", "err", err) + + // Try to recover from peers + err = w.fixManager.RecoverFromPeer(blockHash) + if err != nil { + log.Error("Failed to recover from peers", "err", err) + resultChan <- FixResult{Success: false, Err: err} + return + } + } else { + log.Error("Failed to recover from local data", "err", err) + resultChan <- FixResult{Success: false, Err: err} return } } - log.Info("Fix operation completed") - + log.Info("Fix operation completed successfully") + resultChan <- FixResult{Success: true, Err: nil} } // buildPayload builds the payload according to the provided parameters. @@ -439,7 +450,7 @@ func (w *worker) buildPayload(args *BuildPayloadArgs) (*Payload, error) { // This function reconstructs the block using the provided BuildPayloadArgs and // attempts to update the payload in the system. It performs validation of the // block parameters and updates the payload if the block is successfully built. -func (w *worker) retryPayloadUpdate(args *BuildPayloadArgs, payload *Payload) { +func (w *worker) retryPayloadUpdate(args *BuildPayloadArgs, payload *Payload) error { fullParams := &generateParams{ timestamp: args.Timestamp, forceTime: true, @@ -457,7 +468,8 @@ func (w *worker) retryPayloadUpdate(args *BuildPayloadArgs, payload *Payload) { // validate the BuildPayloadArgs here. 
 	_, err := w.validateParams(fullParams)
 	if err != nil {
-		return
+		log.Error("Failed to validate payload parameters", "id", payload.id, "err", err)
+		return fmt.Errorf("failed to validate payload parameters: %w", err)
 	}
 
 	// set shared interrupt
@@ -466,13 +478,14 @@ func (w *worker) retryPayloadUpdate(args *BuildPayloadArgs, payload *Payload) {
 	r := w.getSealingBlock(fullParams)
 	if r.err != nil {
 		log.Error("Failed to build full payload after fix", "id", payload.id, "err", r.err)
-		return
+		return fmt.Errorf("failed to build full payload after fix: %w", r.err)
 	}
 
 	payload.update(r, 0, func() {
 		w.cacheMiningBlock(r.block, r.env)
 	})
 	log.Info("Successfully updated payload after fix", "id", payload.id)
+	return nil
 }
 
 func (w *worker) cacheMiningBlock(block *types.Block, env *environment) {
diff --git a/miner/worker_test.go b/miner/worker_test.go
index aa05565301..c84a03cd01 100644
--- a/miner/worker_test.go
+++ b/miner/worker_test.go
@@ -34,6 +34,7 @@ import (
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/core/vm"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/eth/downloader"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/event"
 	"github.com/ethereum/go-ethereum/params"
@@ -144,8 +145,9 @@ func newTestWorkerBackend(t *testing.T, chainConfig *params.ChainConfig, engine
 	}
 }
 
-func (b *testWorkerBackend) BlockChain() *core.BlockChain { return b.chain }
-func (b *testWorkerBackend) TxPool() *txpool.TxPool       { return b.txPool }
+func (b *testWorkerBackend) BlockChain() *core.BlockChain       { return b.chain }
+func (b *testWorkerBackend) TxPool() *txpool.TxPool             { return b.txPool }
+func (b *testWorkerBackend) Downloader() *downloader.Downloader { return nil }
 
 func (b *testWorkerBackend) newRandomTx(creation bool) *types.Transaction {
 	var tx *types.Transaction