From 7e9b7968b36b9b4ed1cedc1d6bed3f0c1092f2cf Mon Sep 17 00:00:00 2001 From: Calvin Kim Date: Tue, 2 Apr 2024 17:56:12 +0900 Subject: [PATCH] blockchain: Add InvalidateBlock() method to BlockChain InvalidateBlock() invalidates a given block and marks all its descendents as invalid as well. The active chain tip changes if the invalidated block is part of the best chain. --- blockchain/chain.go | 138 ++++++++++++++ blockchain/chain_test.go | 385 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 523 insertions(+) diff --git a/blockchain/chain.go b/blockchain/chain.go index 7e06e5c77c..8e75b447e9 100644 --- a/blockchain/chain.go +++ b/blockchain/chain.go @@ -1798,6 +1798,144 @@ func (b *BlockChain) LocateHeaders(locator BlockLocator, hashStop *chainhash.Has return headers } +// InvalidateBlock invalidates the requested block and all its descedents. If a block +// in the best chain is invalidated, the active chain tip will be the parent of the +// invalidated block. +// +// This function is safe for concurrent access. +func (b *BlockChain) InvalidateBlock(hash *chainhash.Hash) error { + b.chainLock.Lock() + defer b.chainLock.Unlock() + + node := b.index.LookupNode(hash) + if node == nil { + // Return an error if the block doesn't exist. + return fmt.Errorf("Requested block hash of %s is not found "+ + "and thus cannot be invalidated.", hash) + } + if node.height == 0 { + return fmt.Errorf("Requested block hash of %s is a at height 0 "+ + "and is thus a genesis block and cannot be invalidated.", + node.hash) + } + + // Nothing to do if the given block is already invalid. + if node.status.KnownInvalid() { + return nil + } + + // Set the status of the block being invalidated. + b.index.SetStatusFlags(node, statusValidateFailed) + b.index.UnsetStatusFlags(node, statusValid) + + // If the block we're invalidating is not on the best chain, we simply + // mark the block and all its descendants as invalid and return. + if !b.bestChain.Contains(node) { + // Grab all the tips excluding the active tip. + tips := b.index.InactiveTips(b.bestChain) + for _, tip := range tips { + // Continue if the given inactive tip is not a descendant of the block + // being invalidated. + if !tip.IsAncestor(node) { + continue + } + + // Keep going back until we get to the block being invalidated. + // For each of the parent, we'll unset valid status and set invalid + // ancestor status. + for n := tip; n != nil && n != node; n = n.parent { + // Continue if it's already invalid. + if n.status.KnownInvalid() { + continue + } + b.index.SetStatusFlags(n, statusInvalidAncestor) + b.index.UnsetStatusFlags(n, statusValid) + } + } + + if writeErr := b.index.flushToDB(); writeErr != nil { + return fmt.Errorf("Error flushing block index "+ + "changes to disk: %v", writeErr) + } + + // Return since the block being invalidated is on a side branch. + // Nothing else left to do. + return nil + } + + // If we're here, it means a block from the active chain tip is getting + // invalidated. + // + // Grab all the nodes to detach from the active chain. + detachNodes := list.New() + for n := b.bestChain.Tip(); n != nil && n != node; n = n.parent { + // Continue if it's already invalid. + if n.status.KnownInvalid() { + continue + } + + // Change the status of the block node. + b.index.SetStatusFlags(n, statusInvalidAncestor) + b.index.UnsetStatusFlags(n, statusValid) + detachNodes.PushBack(n) + } + + // Push back the block node being invalidated. + detachNodes.PushBack(node) + + // Reorg back to the parent of the block being invalidated. + // Nothing to attach so just pass an empty list. + err := b.reorganizeChain(detachNodes, list.New()) + if err != nil { + return err + } + + if writeErr := b.index.flushToDB(); writeErr != nil { + log.Warnf("Error flushing block index changes to disk: %v", writeErr) + } + + // Grab all the tips. + tips := b.index.InactiveTips(b.bestChain) + tips = append(tips, b.bestChain.Tip()) + + // Here we'll check if the invalidation of the block in the active tip + // changes the status of the chain tips. If a side branch now has more + // worksum, it becomes the active chain tip. + var bestTip *blockNode + for _, tip := range tips { + // Skip invalid tips as they cannot become the active tip. + if tip.status.KnownInvalid() { + continue + } + + // If we have no best tips, then set this tip as the best tip. + if bestTip == nil { + bestTip = tip + } else { + // If there is an existing best tip, then compare it + // against the current tip. + if tip.workSum.Cmp(bestTip.workSum) == 1 { + bestTip = tip + } + } + } + + // Return if the best tip is the current tip. + if bestTip == b.bestChain.Tip() { + return nil + } + + // Reorganize to the best tip if a side branch is now the most work tip. + detachNodes, attachNodes := b.getReorganizeNodes(bestTip) + err = b.reorganizeChain(detachNodes, attachNodes) + + if writeErr := b.index.flushToDB(); writeErr != nil { + log.Warnf("Error flushing block index changes to disk: %v", writeErr) + } + + return err +} + // IndexManager provides a generic interface that the is called when blocks are // connected and disconnected to and from the tip of the main chain for the // purpose of supporting optional indexes. diff --git a/blockchain/chain_test.go b/blockchain/chain_test.go index 259a643f3c..52e86546a9 100644 --- a/blockchain/chain_test.go +++ b/blockchain/chain_test.go @@ -6,10 +6,12 @@ package blockchain import ( "fmt" + "math/rand" "reflect" "testing" "time" + "github.com/btcsuite/btcd/blockchain/testhelper" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -1311,3 +1313,386 @@ func TestIsAncestor(t *testing.T) { branch2Nodes[0].hash.String()) } } + +func TestInvalidateBlock(t *testing.T) { + tests := []struct { + name string + chainGen func() (*BlockChain, []*chainhash.Hash, func()) + }{ + { + name: "one branch, invalidate once", + chainGen: func() (*BlockChain, []*chainhash.Hash, func()) { + source := rand.NewSource(time.Now().UnixNano()) + rand := rand.New(source) + + chain, params, tearDown := utxoCacheTestChain( + "TestInvalidateBlock-one-branch-" + + "invalidate-once") + + tip := btcutil.NewBlock(params.GenesisBlock) + + // Create block at height 1. + var emptySpendableOuts []*testhelper.SpendableOut + b1, spendableOuts1, err := addBlock(chain, tip, emptySpendableOuts) + if err != nil { + t.Fatal(err) + } + + var allSpends []*testhelper.SpendableOut + nextBlock := b1 + nextSpends := spendableOuts1 + + var invalidateHash *chainhash.Hash + + // Create a chain with 11 blocks. + for b := 0; b < 10; b++ { + newBlock, newSpendableOuts, err := addBlock(chain, nextBlock, nextSpends) + if err != nil { + t.Fatal(err) + } + + nextBlock = newBlock + + if newBlock.Height() == 5 { + invalidateHash = newBlock.Hash() + } + + allSpends = append(allSpends, newSpendableOuts...) + + // Randomly grab utxos from allSpends that we'll be spending in the next + // block. + var nextSpendsTmp []*testhelper.SpendableOut + for i := 0; i < len(allSpends); i++ { + randIdx := rand.Intn(len(allSpends)) + + spend := allSpends[randIdx] // get + allSpends = append(allSpends[:randIdx], allSpends[randIdx+1:]...) // delete + nextSpendsTmp = append(nextSpendsTmp, spend) + } + nextSpends = nextSpendsTmp + + if b%10 == 0 { + // Commit the two base blocks to DB + if err := chain.FlushUtxoCache(FlushRequired); err != nil { + t.Fatalf("unexpected error while flushing cache: %v", err) + } + } + } + + return chain, []*chainhash.Hash{invalidateHash}, tearDown + }, + }, + { + name: "invalidate twice", + chainGen: func() (*BlockChain, []*chainhash.Hash, func()) { + source := rand.NewSource(0) + rand := rand.New(source) + + chain, params, tearDown := utxoCacheTestChain("TestInvalidateBlock-invalidate-twice") + tip := btcutil.NewBlock(params.GenesisBlock) + + // Create block at height 1. + var emptySpendableOuts []*testhelper.SpendableOut + b1, spendableOuts1, err := addBlock(chain, tip, emptySpendableOuts) + if err != nil { + t.Fatal(err) + } + + var allSpends []*testhelper.SpendableOut + nextBlock := b1 + nextSpends := spendableOuts1 + + var invalidateHash *chainhash.Hash + + // Create a chain with 11 blocks. + for b := 0; b < 10; b++ { + newBlock, newSpendableOuts, err := addBlock(chain, nextBlock, nextSpends) + if err != nil { + t.Fatal(err) + } + + nextBlock = newBlock + + if newBlock.Height() == 5 { + invalidateHash = newBlock.Hash() + } + + allSpends = append(allSpends, newSpendableOuts...) + + var nextSpendsTmp []*testhelper.SpendableOut + for i := 0; i < len(allSpends); i++ { + randIdx := rand.Intn(len(allSpends)) + + spend := allSpends[randIdx] // get + allSpends = append(allSpends[:randIdx], allSpends[randIdx+1:]...) // delete + nextSpendsTmp = append(nextSpendsTmp, spend) + } + nextSpends = nextSpendsTmp + + if b%10 == 0 { + // Commit the two base blocks to DB + if err := chain.FlushUtxoCache(FlushRequired); err != nil { + t.Fatalf("unexpected error while flushing cache: %v", err) + } + } + } + + // Create a side chain with 7 blocks that builds on block 1. + var altSpends []*testhelper.SpendableOut + altNextSpends := spendableOuts1 + altNextBlock := b1 + var invalidateHash2 *chainhash.Hash + for b := 0; b < 6; b++ { + altNewBlock, newSpends, err := addBlock(chain, altNextBlock, altNextSpends) + if err != nil { + t.Fatal(err) + } + + altNextBlock = altNewBlock + + altSpends = append(altSpends, newSpends...) + + if altNewBlock.Height() == 5 { + invalidateHash2 = altNewBlock.Hash() + } + + var nextSpendsTmp []*testhelper.SpendableOut + for i := 0; i < len(altSpends); i++ { + randIdx := rand.Intn(len(altSpends)) + + spend := altSpends[randIdx] // get + altSpends = append(altSpends[:randIdx], altSpends[randIdx+1:]...) // delete + nextSpendsTmp = append(nextSpendsTmp, spend) + } + altNextSpends = nextSpendsTmp + + if b%10 == 0 { + // Commit the two base blocks to DB + if err := chain.FlushUtxoCache(FlushRequired); err != nil { + t.Fatalf("unexpected error while flushing cache: %v", err) + } + } + } + + return chain, []*chainhash.Hash{invalidateHash, invalidateHash2}, tearDown + }, + }, + { + name: "invalidate a side branch", + chainGen: func() (*BlockChain, []*chainhash.Hash, func()) { + source := rand.NewSource(0) + rand := rand.New(source) + + chain, params, tearDown := utxoCacheTestChain("TestInvalidateBlock-invalidate-side-branch") + tip := btcutil.NewBlock(params.GenesisBlock) + + // Create block at height 1. + var emptySpendableOuts []*testhelper.SpendableOut + b1, spendableOuts1, err := addBlock(chain, tip, emptySpendableOuts) + if err != nil { + t.Fatal(err) + } + + var allSpends []*testhelper.SpendableOut + nextBlock := b1 + nextSpends := spendableOuts1 + + // Create a chain with 11 blocks. + for b := 0; b < 10; b++ { + newBlock, newSpendableOuts, err := addBlock(chain, nextBlock, nextSpends) + if err != nil { + t.Fatal(err) + } + + nextBlock = newBlock + + allSpends = append(allSpends, newSpendableOuts...) + + var nextSpendsTmp []*testhelper.SpendableOut + for i := 0; i < len(allSpends); i++ { + randIdx := rand.Intn(len(allSpends)) + + spend := allSpends[randIdx] // get + allSpends = append(allSpends[:randIdx], allSpends[randIdx+1:]...) // delete + nextSpendsTmp = append(nextSpendsTmp, spend) + } + nextSpends = nextSpendsTmp + + if b%10 == 0 { + // Commit the two base blocks to DB + if err := chain.FlushUtxoCache(FlushRequired); err != nil { + t.Fatalf("unexpected error while flushing cache: %v", err) + } + } + } + + // Create a side chain with 7 blocks that builds on block 1. + var altSpends []*testhelper.SpendableOut + altNextSpends := spendableOuts1 + altNextBlock := b1 + var invalidateHash *chainhash.Hash + for b := 0; b < 6; b++ { + altNewBlock, newSpends, err := addBlock(chain, altNextBlock, altNextSpends) + if err != nil { + t.Fatal(err) + } + + altNextBlock = altNewBlock + + altSpends = append(altSpends, newSpends...) + + if altNewBlock.Height() == 4 { + invalidateHash = altNewBlock.Hash() + } + + var nextSpendsTmp []*testhelper.SpendableOut + for i := 0; i < len(altSpends); i++ { + randIdx := rand.Intn(len(altSpends)) + + spend := altSpends[randIdx] // get + altSpends = append(altSpends[:randIdx], altSpends[randIdx+1:]...) // delete + nextSpendsTmp = append(nextSpendsTmp, spend) + } + altNextSpends = nextSpendsTmp + + if b%10 == 0 { + // Commit the two base blocks to DB + if err := chain.FlushUtxoCache(FlushRequired); err != nil { + t.Fatalf("unexpected error while flushing cache: %v", err) + } + } + } + + return chain, []*chainhash.Hash{invalidateHash}, tearDown + }, + }, + } + + for _, test := range tests { + chain, invalidateHashes, tearDown := test.chainGen() + func() { + defer tearDown() + for _, invalidateHash := range invalidateHashes { + chainTipsBefore := chain.ChainTips() + + // Mark if we're invalidating a block that's a part of the best chain. + var bestChainBlock bool + node := chain.index.LookupNode(invalidateHash) + if chain.bestChain.Contains(node) { + bestChainBlock = true + } + + // Actual invalidation. + err := chain.InvalidateBlock(invalidateHash) + if err != nil { + t.Fatal(err) + } + + chainTipsAfter := chain.ChainTips() + + // Create a map for easy lookup. + chainTipMap := make(map[chainhash.Hash]ChainTip, len(chainTipsAfter)) + activeTipCount := 0 + for _, chainTip := range chainTipsAfter { + chainTipMap[chainTip.BlockHash] = chainTip + + if chainTip.Status == StatusActive { + activeTipCount++ + } + } + if activeTipCount != 1 { + t.Fatalf("TestInvalidateBlock fail. Expected "+ + "1 active chain tip but got %d", activeTipCount) + } + + bestTip := chain.bestChain.Tip() + + validForkCount := 0 + for _, tip := range chainTipsBefore { + // If the chaintip was an active tip and we invalidated a block + // in the active tip, assert that it's invalid now. + if bestChainBlock && tip.Status == StatusActive { + gotTip, found := chainTipMap[tip.BlockHash] + if !found { + t.Fatalf("TestInvalidateBlock fail. Expected "+ + "block %s not found in chaintips after "+ + "invalidateblock", tip.BlockHash.String()) + } + + if gotTip.Status != StatusInvalid { + t.Fatalf("TestInvalidateBlock fail. "+ + "Expected block %s to be invalid, got status: %s", + gotTip.BlockHash.String(), gotTip.Status) + } + } + + if !bestChainBlock && tip.Status != StatusActive { + gotTip, found := chainTipMap[tip.BlockHash] + if !found { + t.Fatalf("TestInvalidateBlock fail. Expected "+ + "block %s not found in chaintips after "+ + "invalidateblock", tip.BlockHash.String()) + } + + if gotTip.BlockHash == *invalidateHash && gotTip.Status != StatusInvalid { + t.Fatalf("TestInvalidateBlock fail. "+ + "Expected block %s to be invalid, got status: %s", + gotTip.BlockHash.String(), gotTip.Status) + } + } + + // If we're not invalidating the branch with an active tip, + // we expect the active tip to remain the same. + if !bestChainBlock && tip.Status == StatusActive && tip.BlockHash != bestTip.hash { + t.Fatalf("TestInvalidateBlock fail. Expected block %s as the tip but got %s", + tip.BlockHash.String(), bestTip.hash.String()) + } + + // If this tip is not invalid and not active, it should be + // lighter than the current best tip. + if tip.Status != StatusActive && tip.Status != StatusInvalid && + tip.Height > bestTip.height { + + tipNode := chain.index.LookupNode(&tip.BlockHash) + if bestTip.workSum.Cmp(tipNode.workSum) == -1 { + t.Fatalf("TestInvalidateBlock fail. Expected "+ + "block %s to be the active tip but block %s "+ + "was", tipNode.hash.String(), bestTip.hash.String()) + } + } + + if tip.Status == StatusValidFork { + validForkCount++ + } + } + + // If there are no other valid chain tips besides the active chaintip, + // we expect to have one more chain tip after the invalidate. + if validForkCount == 0 && len(chainTipsAfter) != len(chainTipsBefore)+1 { + t.Fatalf("TestInvalidateBlock fail. Expected %d chaintips but got %d", + len(chainTipsBefore)+1, len(chainTipsAfter)) + } + } + + // Try to invaliate the already invalidated hash. + err := chain.InvalidateBlock(invalidateHashes[0]) + if err != nil { + t.Fatal(err) + } + + // Try to invaliate a genesis block + err = chain.InvalidateBlock(chain.chainParams.GenesisHash) + if err == nil { + t.Fatalf("TestInvalidateBlock fail. Expected to err when trying to" + + "invalidate a genesis block.") + } + + // Try to invaliate a block that doesn't exist. + err = chain.InvalidateBlock(chaincfg.MainNetParams.GenesisHash) + if err == nil { + t.Fatalf("TestInvalidateBlock fail. Expected to err when trying to" + + "invalidate a block that doesn't exist.") + } + }() + } +}