diff --git a/contracts b/contracts index accdcee457..436e1cf82c 160000 --- a/contracts +++ b/contracts @@ -1 +1 @@ -Subproject commit accdcee45798af5025836a04ee5bdcb0669cb476 +Subproject commit 436e1cf82c5696eb918d842256328ba86fbe5019 diff --git a/nodeInterface/NodeInterface.go b/nodeInterface/NodeInterface.go index 92ed2064c3..98394f9343 100644 --- a/nodeInterface/NodeInterface.go +++ b/nodeInterface/NodeInterface.go @@ -591,82 +591,82 @@ func (n NodeInterface) LegacyLookupMessageBatchProof(c ctx, evm mech, batchNum h return } -func (n NodeInterface) getL1BlockNum(l2BlockNum uint64) (uint64, error) { +func (n NodeInterface) blockL1Num(l2BlockNum uint64) (uint64, error) { blockHeader, err := n.backend.HeaderByNumber(n.context, rpc.BlockNumber(l2BlockNum)) if err != nil { return 0, err } - l1BlockNum := types.DeserializeHeaderExtraInformation(blockHeader).L1BlockNumber - return l1BlockNum, nil + blockL1Num := types.DeserializeHeaderExtraInformation(blockHeader).L1BlockNumber + return blockL1Num, nil } -func (n NodeInterface) GetL2BlockRangeForL1(c ctx, evm mech, l1BlockNum uint64) ([]uint64, error) { +func (n NodeInterface) matchL2BlockNumWithL1(l2BlockNum uint64, l1BlockNum uint64) error { + blockL1Num, err := n.blockL1Num(l2BlockNum) + if err != nil { + return fmt.Errorf("failed to get the L1 block number of the L2 block: %v. Error: %w", l2BlockNum, err) + } + if blockL1Num != l1BlockNum { + return fmt.Errorf("no L2 block was found with the given L1 block number. Found L2 block: %v with L1 block number: %v, given L1 block number: %v", l2BlockNum, blockL1Num, l1BlockNum) + } + return nil +} + +// L2BlockRangeForL1 finds the first and last L2 block numbers that have the given L1 block number +func (n NodeInterface) L2BlockRangeForL1(c ctx, evm mech, l1BlockNum uint64) (uint64, uint64, error) { currentBlockNum := n.backend.CurrentBlock().Number.Uint64() genesis := n.backend.ChainConfig().ArbitrumChainParams.GenesisBlockNum - checkCorrectness := func(blockNum uint64, target uint64) error { - blockL1Num, err := n.getL1BlockNum(blockNum) - if err != nil { - return err - } - if blockL1Num != target { - return errors.New("no L2 block was found with the given L1 block number") - } - return nil + type helperStruct struct { + low uint64 + high uint64 } - lowFirstBlock := genesis - highFirstBlock := currentBlockNum - lowLastBlock := genesis - highLastBlock := currentBlockNum - var storedMid uint64 - var storedMidBlockL1Num uint64 - for lowFirstBlock < highFirstBlock || lowLastBlock < highLastBlock { - if lowFirstBlock < highFirstBlock { - mid := arbmath.SaturatingUAdd(lowFirstBlock, highFirstBlock) / 2 - midBlockL1Num, err := n.getL1BlockNum(mid) - if err != nil { - return nil, err - } - storedMid = mid - storedMidBlockL1Num = midBlockL1Num - if midBlockL1Num < l1BlockNum { - lowFirstBlock = mid + 1 - } else { - highFirstBlock = mid - } - } - if lowLastBlock < highLastBlock { + searchHelper := func(currentBlock *helperStruct, fetchedMid *helperStruct, target uint64) error { + if currentBlock.low < currentBlock.high { // dont fetch midBlockL1Num if its already fetched above - mid := arbmath.SaturatingUAdd(lowLastBlock, highLastBlock) / 2 + mid := arbmath.SaturatingUAdd(currentBlock.low, currentBlock.high) / 2 var midBlockL1Num uint64 var err error - if mid == storedMid { - midBlockL1Num = storedMidBlockL1Num + if mid == fetchedMid.low { + midBlockL1Num = fetchedMid.high } else { - midBlockL1Num, err = n.getL1BlockNum(mid) + midBlockL1Num, err = n.blockL1Num(mid) if err != nil { - return nil, err + return err } + fetchedMid.low = mid + fetchedMid.high = midBlockL1Num } - if midBlockL1Num < l1BlockNum+1 { - lowLastBlock = mid + 1 + if midBlockL1Num < target { + currentBlock.low = mid + 1 } else { - highLastBlock = mid + currentBlock.high = mid } + return nil } + return nil } - err := checkCorrectness(highFirstBlock, l1BlockNum) - if err != nil { - return nil, err + firstBlock := &helperStruct{low: genesis, high: currentBlockNum} + lastBlock := &helperStruct{low: genesis, high: currentBlockNum} + // in storedMid low corresponds to value mid and high corresponds to midBlockL1Num inside searchHelper + storedMid := &helperStruct{low: currentBlockNum + 1} + var err error + for firstBlock.low < firstBlock.high || lastBlock.low < lastBlock.high { + if err = searchHelper(firstBlock, storedMid, l1BlockNum); err != nil { + return 0, 0, err + } + if err = searchHelper(lastBlock, storedMid, l1BlockNum+1); err != nil { + return 0, 0, err + } } - err = checkCorrectness(highLastBlock, l1BlockNum) - if err != nil { - highLastBlock -= 1 - err = checkCorrectness(highLastBlock, l1BlockNum) - if err != nil { - return nil, err + if err := n.matchL2BlockNumWithL1(firstBlock.high, l1BlockNum); err != nil { + return 0, 0, err + } + if err := n.matchL2BlockNumWithL1(lastBlock.high, l1BlockNum); err != nil { + lastBlock.high -= 1 + if err = n.matchL2BlockNumWithL1(lastBlock.high, l1BlockNum); err != nil { + return 0, 0, err } } - return []uint64{highFirstBlock, highLastBlock}, nil + return firstBlock.high, lastBlock.high, nil } diff --git a/system_tests/nodeinterface_test.go b/system_tests/nodeinterface_test.go index 266b50d6c8..bfdff3d02d 100644 --- a/system_tests/nodeinterface_test.go +++ b/system_tests/nodeinterface_test.go @@ -10,19 +10,11 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/ethclient" "github.com/offchainlabs/nitro/arbos/util" "github.com/offchainlabs/nitro/solgen/go/node_interfacegen" ) -func getL1BlockNum(t *testing.T, ctx context.Context, client *ethclient.Client, l2BlockNum uint64) uint64 { - header, err := client.HeaderByNumber(ctx, big.NewInt(int64(l2BlockNum))) - Require(t, err) - l1BlockNum := types.DeserializeHeaderExtraInformation(header).L1BlockNumber - return l1BlockNum -} - -func TestGetL2BlockRangeForL1(t *testing.T) { +func TestL2BlockRangeForL1(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -32,7 +24,7 @@ func TestGetL2BlockRangeForL1(t *testing.T) { defer node.StopAndWait() user := l1info.GetDefaultTransactOpts("User", ctx) - numTransactions := 30 + numTransactions := 200 for i := 0; i < numTransactions; i++ { TransferBalanceTo(t, "Owner", util.RemapL1Address(user.From), big.NewInt(1e18), l2info, l2client, ctx) } @@ -40,31 +32,42 @@ func TestGetL2BlockRangeForL1(t *testing.T) { nodeInterface, err := node_interfacegen.NewNodeInterface(types.NodeInterfaceAddress, l2client) Require(t, err) + getBlockL1Num := func(l2BlockNum uint64) uint64 { + header, err := l2client.HeaderByNumber(ctx, big.NewInt(int64(l2BlockNum))) + Require(t, err) + l1BlockNum := types.DeserializeHeaderExtraInformation(header).L1BlockNumber + return l1BlockNum + } + l1BlockNums := map[uint64][]uint64{} latestL2, err := l2client.BlockNumber(ctx) Require(t, err) for l2BlockNum := uint64(0); l2BlockNum <= latestL2; l2BlockNum++ { - l1BlockNum := getL1BlockNum(t, ctx, l2client, l2BlockNum) - l1BlockNums[l1BlockNum] = append(l1BlockNums[l1BlockNum], l2BlockNum) + l1BlockNum := getBlockL1Num(l2BlockNum) + if len(l1BlockNums[l1BlockNum]) <= 1 { + l1BlockNums[l1BlockNum] = append(l1BlockNums[l1BlockNum], l2BlockNum) + } else { + l1BlockNums[l1BlockNum][1] = l2BlockNum + } } // Test success for l1BlockNum := range l1BlockNums { - rng, err := nodeInterface.GetL2BlockRangeForL1(&bind.CallOpts{}, l1BlockNum) + rng, err := nodeInterface.L2BlockRangeForL1(&bind.CallOpts{}, l1BlockNum) Require(t, err) n := len(l1BlockNums[l1BlockNum]) expected := []uint64{l1BlockNums[l1BlockNum][0], l1BlockNums[l1BlockNum][n-1]} - if expected[0] != rng[0] || expected[1] != rng[1] { - unexpectedL1BlockNum := getL1BlockNum(t, ctx, l2client, rng[1]) - // handle the edge case when new l2 blocks are produced between latestL2 was last calculated and now - if unexpectedL1BlockNum != l1BlockNum { - t.Fatalf("GetL2BlockRangeForL1 failed to get a valid range for L1 block number: %v. Given range: %v. Expected range: %v", l1BlockNum, rng, expected) + if expected[0] != rng.FirstBlock || expected[1] != rng.LastBlock { + unexpectedL1BlockNum := getBlockL1Num(rng.LastBlock) + // Handle the edge case when new l2 blocks are produced between latestL2 was last calculated and now. + if unexpectedL1BlockNum != l1BlockNum || rng.LastBlock < expected[1] { + t.Errorf("L2BlockRangeForL1(%d) = (%d %d) want (%d %d)", l1BlockNum, rng.FirstBlock, rng.LastBlock, expected[0], expected[1]) } } } // Test invalid case - finalValidL1BlockNumber := getL1BlockNum(t, ctx, l2client, latestL2) - _, err = nodeInterface.GetL2BlockRangeForL1(&bind.CallOpts{}, finalValidL1BlockNumber+1) + finalValidL1BlockNumber := getBlockL1Num(latestL2) + _, err = nodeInterface.L2BlockRangeForL1(&bind.CallOpts{}, finalValidL1BlockNumber+1) if err == nil { t.Fatalf("GetL2BlockRangeForL1 didn't fail for an invalid input") }