diff --git a/arbnode/dataposter/data_poster.go b/arbnode/dataposter/data_poster.go
index 7bc18a2121..fb35ac3c8d 100644
--- a/arbnode/dataposter/data_poster.go
+++ b/arbnode/dataposter/data_poster.go
@@ -217,6 +217,10 @@ func NewDataPoster(ctx context.Context, opts *DataPosterOpts) (*DataPoster, erro
 func rpcClient(ctx context.Context, opts *ExternalSignerCfg) (*rpc.Client, error) {
     tlsCfg := &tls.Config{
         MinVersion: tls.VersionTLS12,
+        // The dataposter verifies that the signed transaction was signed by the
+        // account it expects, so the signer is already authenticated at the
+        // application level and does not need to rely on TLS for authentication.
+        InsecureSkipVerify: opts.InsecureSkipVerify, // #nosec G402
     }
 
     if opts.ClientCert != "" && opts.ClientPrivateKey != "" {
@@ -1223,6 +1227,8 @@ type ExternalSignerCfg struct {
     // (Optional) Client certificate key for mtls.
     // This is required when client-cert is set.
     ClientPrivateKey string `koanf:"client-private-key"`
+    // (Optional) When enabled, skips TLS certificate verification of the external signer.
+    InsecureSkipVerify bool `koanf:"insecure-skip-verify"`
 }
 
 type DangerousConfig struct {
@@ -1276,6 +1282,7 @@ func addExternalSignerOptions(prefix string, f *pflag.FlagSet) {
     f.String(prefix+".root-ca", DefaultDataPosterConfig.ExternalSigner.RootCA, "external signer root CA")
     f.String(prefix+".client-cert", DefaultDataPosterConfig.ExternalSigner.ClientCert, "rpc client cert")
     f.String(prefix+".client-private-key", DefaultDataPosterConfig.ExternalSigner.ClientPrivateKey, "rpc client private key")
+    f.Bool(prefix+".insecure-skip-verify", DefaultDataPosterConfig.ExternalSigner.InsecureSkipVerify, "skip TLS certificate verification")
 }
 
 var DefaultDataPosterConfig = DataPosterConfig{
@@ -1297,7 +1304,7 @@ var DefaultDataPosterConfig = DataPosterConfig{
     UseNoOpStorage:        false,
     LegacyStorageEncoding: false,
     Dangerous:             DangerousConfig{ClearDBStorage: false},
-    ExternalSigner:        ExternalSignerCfg{Method: "eth_signTransaction"},
+    ExternalSigner:        ExternalSignerCfg{Method: "eth_signTransaction", InsecureSkipVerify: false},
     MaxFeeCapFormula:      "((BacklogOfBatches * UrgencyGWei) ** 2) + ((ElapsedTime/ElapsedTimeBase) ** 2) * ElapsedTimeImportance + TargetPriceGWei",
     ElapsedTimeBase:       10 * time.Minute,
     ElapsedTimeImportance: 10,
@@ -1330,7 +1337,7 @@ var TestDataPosterConfig = DataPosterConfig{
     UseDBStorage:          false,
     UseNoOpStorage:        false,
     LegacyStorageEncoding: false,
-    ExternalSigner:        ExternalSignerCfg{Method: "eth_signTransaction"},
+    ExternalSigner:        ExternalSignerCfg{Method: "eth_signTransaction", InsecureSkipVerify: true},
     MaxFeeCapFormula:      "((BacklogOfBatches * UrgencyGWei) ** 2) + ((ElapsedTime/ElapsedTimeBase) ** 2) * ElapsedTimeImportance + TargetPriceGWei",
     ElapsedTimeBase:       10 * time.Minute,
     ElapsedTimeImportance: 10,
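Note on the InsecureSkipVerify rationale: the application-level check the new comment relies on amounts to recovering the sender of the returned signed transaction and comparing it against the expected signer address. A minimal sketch of that idea (hypothetical helper, not the dataposter's actual code), using go-ethereum's types.Sender:

    // verifyExternalSigner recovers the signer of tx and compares it to the
    // address the dataposter expects. Name and shape are illustrative only.
    func verifyExternalSigner(tx *types.Transaction, expected common.Address, chainID *big.Int) error {
        sender, err := types.Sender(types.LatestSignerForChainID(chainID), tx)
        if err != nil {
            return err
        }
        if sender != expected {
            return fmt.Errorf("transaction signed by %v, expected %v", sender, expected)
        }
        return nil
    }

With a check like this in place, a spoofed TLS endpoint can at worst withhold service; it cannot substitute a transaction signed by a different key.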
diff --git a/arbnode/message_pruner.go b/arbnode/message_pruner.go
index 31bf1a63ff..5d18341a27 100644
--- a/arbnode/message_pruner.go
+++ b/arbnode/message_pruner.go
@@ -23,13 +23,14 @@ import (
 type MessagePruner struct {
     stopwaiter.StopWaiter
-    transactionStreamer         *TransactionStreamer
-    inboxTracker                *InboxTracker
-    config                      MessagePrunerConfigFetcher
-    pruningLock                 sync.Mutex
-    lastPruneDone               time.Time
-    cachedPrunedMessages        uint64
-    cachedPrunedDelayedMessages uint64
+    transactionStreamer              *TransactionStreamer
+    inboxTracker                     *InboxTracker
+    config                           MessagePrunerConfigFetcher
+    pruningLock                      sync.Mutex
+    lastPruneDone                    time.Time
+    cachedPrunedMessages             uint64
+    cachedPrunedBlockHashesInputFeed uint64
+    cachedPrunedDelayedMessages      uint64
 }
 
 type MessagePrunerConfig struct {
@@ -115,7 +116,15 @@ func (m *MessagePruner) prune(ctx context.Context, count arbutil.MessageIndex, g
 }
 
 func (m *MessagePruner) deleteOldMessagesFromDB(ctx context.Context, messageCount arbutil.MessageIndex, delayedMessageCount uint64) error {
-    prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, m.transactionStreamer.db, messagePrefix, &m.cachedPrunedMessages, uint64(messageCount))
+    prunedKeysRange, err := deleteFromLastPrunedUptoEndKey(ctx, m.transactionStreamer.db, blockHashInputFeedPrefix, &m.cachedPrunedBlockHashesInputFeed, uint64(messageCount))
+    if err != nil {
+        return fmt.Errorf("error deleting expected block hashes: %w", err)
+    }
+    if len(prunedKeysRange) > 0 {
+        log.Info("Pruned expected block hashes", "first pruned key", prunedKeysRange[0], "last pruned key", prunedKeysRange[len(prunedKeysRange)-1])
+    }
+
+    prunedKeysRange, err = deleteFromLastPrunedUptoEndKey(ctx, m.transactionStreamer.db, messagePrefix, &m.cachedPrunedMessages, uint64(messageCount))
     if err != nil {
         return fmt.Errorf("error deleting last batch messages: %w", err)
     }
diff --git a/arbnode/message_pruner_test.go b/arbnode/message_pruner_test.go
index 0212ed2364..ed85c0ebce 100644
--- a/arbnode/message_pruner_test.go
+++ b/arbnode/message_pruner_test.go
@@ -22,8 +22,8 @@ func TestMessagePrunerWithPruningEligibleMessagePresent(t *testing.T) {
     Require(t, err)
 
     checkDbKeys(t, messagesCount, transactionStreamerDb, messagePrefix)
+    checkDbKeys(t, messagesCount, transactionStreamerDb, blockHashInputFeedPrefix)
     checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix)
-
 }
 
 func TestMessagePrunerTwoHalves(t *testing.T) {
@@ -71,16 +71,18 @@ func TestMessagePrunerWithNoPruningEligibleMessagePresent(t *testing.T) {
     Require(t, err)
 
     checkDbKeys(t, uint64(messagesCount), transactionStreamerDb, messagePrefix)
+    checkDbKeys(t, uint64(messagesCount), transactionStreamerDb, blockHashInputFeedPrefix)
     checkDbKeys(t, messagesCount, inboxTrackerDb, rlpDelayedMessagePrefix)
 }
 
 func setupDatabase(t *testing.T, messageCount, delayedMessageCount uint64) (ethdb.Database, ethdb.Database, *MessagePruner) {
-
     transactionStreamerDb := rawdb.NewMemoryDatabase()
     for i := uint64(0); i < uint64(messageCount); i++ {
         err := transactionStreamerDb.Put(dbKey(messagePrefix, i), []byte{})
         Require(t, err)
+        err = transactionStreamerDb.Put(dbKey(blockHashInputFeedPrefix, i), []byte{})
+        Require(t, err)
     }
 
     inboxTrackerDb := rawdb.NewMemoryDatabase()
diff --git a/arbnode/schema.go b/arbnode/schema.go
index ddc7cf54fd..2854b7e785 100644
--- a/arbnode/schema.go
+++ b/arbnode/schema.go
@@ -5,6 +5,7 @@ package arbnode
 
 var (
     messagePrefix                []byte = []byte("m") // maps a message sequence number to a message
+    blockHashInputFeedPrefix     []byte = []byte("b") // maps a message sequence number to a block hash received through the input feed
     legacyDelayedMessagePrefix   []byte = []byte("d") // maps a delayed sequence number to an accumulator and a message as serialized on L1
     rlpDelayedMessagePrefix      []byte = []byte("e") // maps a delayed sequence number to an accumulator and an RLP encoded message
     parentChainBlockNumberPrefix []byte = []byte("p") // maps a delayed sequence number to a parent chain block number
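The new "b" prefix follows the same key scheme as the other namespaces: a prefix byte followed by the big-endian sequence number, which is what lets deleteFromLastPrunedUptoEndKey range-delete every entry below a given message count. A sketch of the assumed layout (the real dbKey helper lives in transaction_streamer.go):

    // Keys sort lexicographically, and big-endian encoding makes that match
    // numeric order within a prefix, so "delete everything below count" is a
    // simple ordered scan from the last pruned key up to the end key.
    func exampleDBKey(prefix []byte, pos uint64) []byte {
        var buf [8]byte
        binary.BigEndian.PutUint64(buf[:], pos)
        return append(append([]byte{}, prefix...), buf[:]...)
    }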
diff --git a/arbnode/transaction_streamer.go b/arbnode/transaction_streamer.go
index 0d5ae829b0..b79b1aa963 100644
--- a/arbnode/transaction_streamer.go
+++ b/arbnode/transaction_streamer.go
@@ -60,7 +60,7 @@ type TransactionStreamer struct {
 
     nextAllowedFeedReorgLog time.Time
 
-    broadcasterQueuedMessages            []arbostypes.MessageWithMetadata
+    broadcasterQueuedMessages            []arbostypes.MessageWithMetadataAndBlockHash
     broadcasterQueuedMessagesPos         uint64
     broadcasterQueuedMessagesActiveReorg bool
@@ -140,6 +140,16 @@ type L1PriceData struct {
     currentEstimateOfL1GasPrice uint64
 }
 
+// blockHashDBValue represents a block's hash as stored in the database.
+// Necessary because the RLP decoder doesn't produce nil values by default.
+type blockHashDBValue struct {
+    BlockHash *common.Hash `rlp:"nil"`
+}
+
+const (
+    BlockHashMismatchLogMsg = "BlockHash from feed doesn't match locally computed hash. Check feed source."
+)
+
 func (s *TransactionStreamer) CurrentEstimateOfL1GasPrice() uint64 {
     s.cachedL1PriceDataMutex.Lock()
     defer s.cachedL1PriceDataMutex.Unlock()
@@ -371,7 +381,7 @@ func deleteFromRange(ctx context.Context, db ethdb.Database, prefix []byte, star
 
 // The insertion mutex must be held. This acquires the reorg mutex.
 // Note: oldMessages will be empty if reorgHook is nil
-func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata) error {
+func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadataAndBlockHash) error {
     if count == 0 {
         return errors.New("cannot reorg out init message")
     }
@@ -465,14 +475,14 @@ func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageInde
         return err
     }
 
-    messagesWithBlockHash := make([]broadcaster.MessageWithMetadataAndBlockHash, 0, len(messagesResults))
+    messagesWithComputedBlockHash := make([]arbostypes.MessageWithMetadataAndBlockHash, 0, len(messagesResults))
     for i := 0; i < len(messagesResults); i++ {
-        messagesWithBlockHash = append(messagesWithBlockHash, broadcaster.MessageWithMetadataAndBlockHash{
-            Message:   newMessages[i],
-            BlockHash: &messagesResults[i].BlockHash,
+        messagesWithComputedBlockHash = append(messagesWithComputedBlockHash, arbostypes.MessageWithMetadataAndBlockHash{
+            MessageWithMeta: newMessages[i].MessageWithMeta,
+            BlockHash:       &messagesResults[i].BlockHash,
         })
     }
-    s.broadcastMessages(messagesWithBlockHash, count)
+    s.broadcastMessages(messagesWithComputedBlockHash, count)
 
     if s.validator != nil {
         err = s.validator.Reorg(s.GetContext(), count)
@@ -481,6 +491,10 @@ func (s *TransactionStreamer) reorg(batch ethdb.Batch, count arbutil.MessageInde
         }
     }
 
+    err = deleteStartingAt(s.db, batch, blockHashInputFeedPrefix, uint64ToKey(uint64(count)))
+    if err != nil {
+        return err
+    }
     err = deleteStartingAt(s.db, batch, messagePrefix, uint64ToKey(uint64(count)))
     if err != nil {
         return err
@@ -510,6 +524,10 @@ func dbKey(prefix []byte, pos uint64) []byte {
     return key
 }
 
+func isErrNotFound(err error) bool {
+    return errors.Is(err, leveldb.ErrNotFound) || errors.Is(err, pebble.ErrNotFound)
+}
+
 // Note: if changed to acquire the mutex, some internal users may need to be updated to a non-locking version.
 func (s *TransactionStreamer) GetMessage(seqNum arbutil.MessageIndex) (*arbostypes.MessageWithMetadata, error) {
     key := dbKey(messagePrefix, uint64(seqNum))
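A note on the `rlp:"nil"` tag carried by blockHashDBValue: RLP has no native nil, and by default the decoder instantiates pointer fields rather than leaving them nil. The tag makes a zero-size value decode back to a nil pointer, which getMessageWithMetadataAndBlockHash below depends on. A standalone round-trip sketch (hypothetical helper inside the arbnode package):

    // nil in, nil out, but only because of the `rlp:"nil"` tag on BlockHash.
    func roundTrip(h *common.Hash) (*common.Hash, error) {
        enc, err := rlp.EncodeToBytes(blockHashDBValue{BlockHash: h})
        if err != nil {
            return nil, err
        }
        var dec blockHashDBValue
        if err := rlp.DecodeBytes(enc, &dec); err != nil {
            return nil, err
        }
        return dec.BlockHash, nil
    }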
@@ -526,6 +544,36 @@ func (s *TransactionStreamer) GetMessage(seqNum arbutil.MessageIndex) (*arbostyp
     return &message, nil
 }
 
+func (s *TransactionStreamer) getMessageWithMetadataAndBlockHash(seqNum arbutil.MessageIndex) (*arbostypes.MessageWithMetadataAndBlockHash, error) {
+    msg, err := s.GetMessage(seqNum)
+    if err != nil {
+        return nil, err
+    }
+
+    // Get the block hash. For backwards compatibility we tolerate its absence:
+    // a message for a given sequence number may exist in the database without
+    // a corresponding block hash entry.
+    key := dbKey(blockHashInputFeedPrefix, uint64(seqNum))
+    var blockHash *common.Hash
+    data, err := s.db.Get(key)
+    if err == nil {
+        var blockHashDBVal blockHashDBValue
+        err = rlp.DecodeBytes(data, &blockHashDBVal)
+        if err != nil {
+            return nil, err
+        }
+        blockHash = blockHashDBVal.BlockHash
+    } else if !isErrNotFound(err) {
+        return nil, err
+    }
+
+    msgWithBlockHash := arbostypes.MessageWithMetadataAndBlockHash{
+        MessageWithMeta: *msg,
+        BlockHash:       blockHash,
+    }
+    return &msgWithBlockHash, nil
+}
+
 // Note: if changed to acquire the mutex, some internal users may need to be updated to a non-locking version.
 func (s *TransactionStreamer) GetMessageCount() (arbutil.MessageIndex, error) {
     posBytes, err := s.db.Get(messageCountKey)
@@ -579,7 +627,7 @@ func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFe
         return nil
     }
     broadcastStartPos := feedMessages[0].SequenceNumber
-    var messages []arbostypes.MessageWithMetadata
+    var messages []arbostypes.MessageWithMetadataAndBlockHash
     broadcastAfterPos := broadcastStartPos
     for _, feedMessage := range feedMessages {
         if broadcastAfterPos != feedMessage.SequenceNumber {
@@ -588,7 +636,11 @@ func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFe
         if feedMessage.Message.Message == nil || feedMessage.Message.Message.Header == nil {
             return fmt.Errorf("invalid feed message at sequence number %v", feedMessage.SequenceNumber)
         }
-        messages = append(messages, feedMessage.Message)
+        msgWithBlockHash := arbostypes.MessageWithMetadataAndBlockHash{
+            MessageWithMeta: feedMessage.Message,
+            BlockHash:       feedMessage.BlockHash,
+        }
+        messages = append(messages, msgWithBlockHash)
         broadcastAfterPos++
     }
 
@@ -607,7 +659,7 @@ func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFe
     messages = messages[dups:]
     broadcastStartPos += arbutil.MessageIndex(dups)
     if oldMsg != nil {
-        s.logReorg(broadcastStartPos, oldMsg, &messages[0], false)
+        s.logReorg(broadcastStartPos, oldMsg, &messages[0].MessageWithMeta, false)
     }
     if len(messages) == 0 {
         // No new messages received
@@ -657,7 +709,7 @@ func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFe
     if broadcastStartPos > 0 {
         _, err := s.GetMessage(broadcastStartPos - 1)
         if err != nil {
-            if !errors.Is(err, leveldb.ErrNotFound) && !errors.Is(err, pebble.ErrNotFound) {
+            if !isErrNotFound(err) {
                 return err
             }
             // Message before current message doesn't exist in database, so don't add current messages yet
@@ -709,11 +761,18 @@ func endBatch(batch ethdb.Batch) error {
 }
 
 func (s *TransactionStreamer) AddMessagesAndEndBatch(pos arbutil.MessageIndex, messagesAreConfirmed bool, messages []arbostypes.MessageWithMetadata, batch ethdb.Batch) error {
+    messagesWithBlockHash := make([]arbostypes.MessageWithMetadataAndBlockHash, 0, len(messages))
+    for _, message := range messages {
+        messagesWithBlockHash = append(messagesWithBlockHash, arbostypes.MessageWithMetadataAndBlockHash{
+            MessageWithMeta: message,
+        })
+    }
+
     if messagesAreConfirmed {
         // Trim confirmed messages from l1pricedataCache
         s.TrimCache(pos + arbutil.MessageIndex(len(messages)))
         s.reorgMutex.RLock()
-        dups, _, _, err := s.countDuplicateMessages(pos, messages, nil)
+        dups, _, _, err := s.countDuplicateMessages(pos, messagesWithBlockHash, nil)
         s.reorgMutex.RUnlock()
         if err != nil {
             return err
@@ -730,7 +789,7 @@ func (s *TransactionStreamer) AddMessagesAndEndBatch(pos arbutil.MessageIndex, m
     s.insertionMutex.Lock()
     defer s.insertionMutex.Unlock()
 
-    return s.addMessagesAndEndBatchImpl(pos, messagesAreConfirmed, messages, batch)
+    return s.addMessagesAndEndBatchImpl(pos, messagesAreConfirmed, messagesWithBlockHash, batch)
 }
 
 func (s *TransactionStreamer) getPrevPrevDelayedRead(pos arbutil.MessageIndex) (uint64, error) {
@@ -748,7 +807,7 @@ func (s *TransactionStreamer) getPrevPrevDelayedRead(pos arbutil.MessageIndex) (
 
 func (s *TransactionStreamer) countDuplicateMessages(
     pos arbutil.MessageIndex,
-    messages []arbostypes.MessageWithMetadata,
+    messages []arbostypes.MessageWithMetadataAndBlockHash,
     batch *ethdb.Batch,
 ) (int, bool, *arbostypes.MessageWithMetadata, error) {
     curMsg := 0
@@ -769,7 +828,7 @@ func (s *TransactionStreamer) countDuplicateMessages(
             return 0, false, nil, err
         }
         nextMessage := messages[curMsg]
-        wantMessage, err := rlp.EncodeToBytes(nextMessage)
+        wantMessage, err := rlp.EncodeToBytes(nextMessage.MessageWithMeta)
         if err != nil {
             return 0, false, nil, err
         }
@@ -785,12 +844,12 @@ func (s *TransactionStreamer) countDuplicateMessages(
                 return curMsg, true, nil, nil
             }
             var duplicateMessage bool
-            if nextMessage.Message != nil {
-                if dbMessageParsed.Message.BatchGasCost == nil || nextMessage.Message.BatchGasCost == nil {
+            if nextMessage.MessageWithMeta.Message != nil {
+                if dbMessageParsed.Message.BatchGasCost == nil || nextMessage.MessageWithMeta.Message.BatchGasCost == nil {
                     // Remove both of the batch gas costs and see if the messages still differ
-                    nextMessageCopy := nextMessage
+                    nextMessageCopy := nextMessage.MessageWithMeta
                     nextMessageCopy.Message = new(arbostypes.L1IncomingMessage)
-                    *nextMessageCopy.Message = *nextMessage.Message
+                    *nextMessageCopy.Message = *nextMessage.MessageWithMeta.Message
                     batchGasCostBkup := dbMessageParsed.Message.BatchGasCost
                     dbMessageParsed.Message.BatchGasCost = nil
                     nextMessageCopy.Message.BatchGasCost = nil
@@ -798,7 +857,7 @@ func (s *TransactionStreamer) countDuplicateMessages(
                         // Actually this isn't a reorg; only the batch gas costs differed
                         duplicateMessage = true
                         // If possible - update the message in the database to add the gas cost cache.
-                        if batch != nil && nextMessage.Message.BatchGasCost != nil {
+                        if batch != nil && nextMessage.MessageWithMeta.Message.BatchGasCost != nil {
                             if *batch == nil {
                                 *batch = s.db.NewBatch()
                             }
@@ -842,7 +901,7 @@ func (s *TransactionStreamer) logReorg(pos arbutil.MessageIndex, dbMsg *arbostyp
 
 }
 
-func (s *TransactionStreamer) addMessagesAndEndBatchImpl(messageStartPos arbutil.MessageIndex, messagesAreConfirmed bool, messages []arbostypes.MessageWithMetadata, batch ethdb.Batch) error {
+func (s *TransactionStreamer) addMessagesAndEndBatchImpl(messageStartPos arbutil.MessageIndex, messagesAreConfirmed bool, messages []arbostypes.MessageWithMetadataAndBlockHash, batch ethdb.Batch) error {
     var confirmedReorg bool
     var oldMsg *arbostypes.MessageWithMetadata
     var lastDelayedRead uint64
@@ -860,7 +919,7 @@ func (s *TransactionStreamer) addMessagesAndEndBatchImpl(messageStartPos arbutil
             return err
         }
         if duplicates > 0 {
-            lastDelayedRead = messages[duplicates-1].DelayedMessagesRead
+            lastDelayedRead = messages[duplicates-1].MessageWithMeta.DelayedMessagesRead
             messages = messages[duplicates:]
             messageStartPos += arbutil.MessageIndex(duplicates)
         }
@@ -898,13 +957,13 @@ func (s *TransactionStreamer) addMessagesAndEndBatchImpl(messageStartPos arbutil
             return err
         }
         if duplicates > 0 {
-            lastDelayedRead = messages[duplicates-1].DelayedMessagesRead
+            lastDelayedRead = messages[duplicates-1].MessageWithMeta.DelayedMessagesRead
             messages = messages[duplicates:]
             messageStartPos += arbutil.MessageIndex(duplicates)
         }
     }
     if oldMsg != nil {
-        s.logReorg(messageStartPos, oldMsg, &messages[0], confirmedReorg)
+        s.logReorg(messageStartPos, oldMsg, &messages[0].MessageWithMeta, confirmedReorg)
     }
 
     if feedReorg {
@@ -924,12 +983,12 @@ func (s *TransactionStreamer) addMessagesAndEndBatchImpl(messageStartPos arbutil
     // Validate delayed message counts of remaining messages
     for i, msg := range messages {
         msgPos := messageStartPos + arbutil.MessageIndex(i)
-        diff := msg.DelayedMessagesRead - lastDelayedRead
+        diff := msg.MessageWithMeta.DelayedMessagesRead - lastDelayedRead
         if diff != 0 && diff != 1 {
-            return fmt.Errorf("attempted to insert jump from %v delayed messages read to %v delayed messages read at message index %v", lastDelayedRead, msg.DelayedMessagesRead, msgPos)
+            return fmt.Errorf("attempted to insert jump from %v delayed messages read to %v delayed messages read at message index %v", lastDelayedRead, msg.MessageWithMeta.DelayedMessagesRead, msgPos)
         }
-        lastDelayedRead = msg.DelayedMessagesRead
-        if msg.Message == nil {
+        lastDelayedRead = msg.MessageWithMeta.DelayedMessagesRead
+        if msg.MessageWithMeta.Message == nil {
             return fmt.Errorf("attempted to insert nil message at position %v", msgPos)
         }
     }
@@ -1007,15 +1066,15 @@ func (s *TransactionStreamer) WriteMessageFromSequencer(
         }
     }
 
-    if err := s.writeMessages(pos, []arbostypes.MessageWithMetadata{msgWithMeta}, nil); err != nil {
-        return err
+    msgWithBlockHash := arbostypes.MessageWithMetadataAndBlockHash{
+        MessageWithMeta: msgWithMeta,
+        BlockHash:       &msgResult.BlockHash,
     }
 
-    msgWithBlockHash := broadcaster.MessageWithMetadataAndBlockHash{
-        Message:   msgWithMeta,
-        BlockHash: &msgResult.BlockHash,
+    if err := s.writeMessages(pos, []arbostypes.MessageWithMetadataAndBlockHash{msgWithBlockHash}, nil); err != nil {
+        return err
     }
-    s.broadcastMessages([]broadcaster.MessageWithMetadataAndBlockHash{msgWithBlockHash}, pos)
+    s.broadcastMessages([]arbostypes.MessageWithMetadataAndBlockHash{msgWithBlockHash}, pos)
 
     return nil
 }
@@ -1036,9 +1095,23 @@ func (s *TransactionStreamer) PopulateFeedBacklog() error {
     return s.inboxReader.tracker.PopulateFeedBacklog(s.broadcastServer)
 }
 
-func (s *TransactionStreamer) writeMessage(pos arbutil.MessageIndex, msg arbostypes.MessageWithMetadata, batch ethdb.Batch) error {
+func (s *TransactionStreamer) writeMessage(pos arbutil.MessageIndex, msg arbostypes.MessageWithMetadataAndBlockHash, batch ethdb.Batch) error {
+    // write message with metadata
     key := dbKey(messagePrefix, uint64(pos))
-    msgBytes, err := rlp.EncodeToBytes(msg)
+    msgBytes, err := rlp.EncodeToBytes(msg.MessageWithMeta)
+    if err != nil {
+        return err
+    }
+    if err := batch.Put(key, msgBytes); err != nil {
+        return err
+    }
+
+    // write block hash
+    blockHashDBVal := blockHashDBValue{
+        BlockHash: msg.BlockHash,
+    }
+    key = dbKey(blockHashInputFeedPrefix, uint64(pos))
+    msgBytes, err = rlp.EncodeToBytes(blockHashDBVal)
     if err != nil {
         return err
     }
@@ -1046,7 +1119,7 @@ func (s *TransactionStreamer) writeMessage(pos arbutil.MessageIndex, msg arbosty
 }
 
 func (s *TransactionStreamer) broadcastMessages(
-    msgs []broadcaster.MessageWithMetadataAndBlockHash,
+    msgs []arbostypes.MessageWithMetadataAndBlockHash,
     pos arbutil.MessageIndex,
 ) {
     if s.broadcastServer == nil {
@@ -1059,7 +1132,7 @@ func (s *TransactionStreamer) broadcastMessages(
 
 // The mutex must be held, and pos must be the latest message count.
 // `batch` may be nil, which initializes a new batch. The batch is closed out in this function.
-func (s *TransactionStreamer) writeMessages(pos arbutil.MessageIndex, messages []arbostypes.MessageWithMetadata, batch ethdb.Batch) error {
+func (s *TransactionStreamer) writeMessages(pos arbutil.MessageIndex, messages []arbostypes.MessageWithMetadataAndBlockHash, batch ethdb.Batch) error {
     if batch == nil {
         batch = s.db.NewBatch()
     }
@@ -1095,6 +1168,20 @@ func (s *TransactionStreamer) ResultAtCount(count arbutil.MessageIndex) (*execut
     return s.exec.ResultAtPos(count - 1)
 }
 
+func (s *TransactionStreamer) checkResult(msgResult *execution.MessageResult, expectedBlockHash *common.Hash) {
+    if expectedBlockHash == nil {
+        return
+    }
+    if msgResult.BlockHash != *expectedBlockHash {
+        log.Error(
+            BlockHashMismatchLogMsg,
+            "expected", expectedBlockHash,
+            "actual", msgResult.BlockHash,
+        )
+        return
+    }
+}
+
 // exposed for testing
 // return value: true if should be called again immediately
 func (s *TransactionStreamer) ExecuteNextMsg(ctx context.Context, exec execution.ExecutionSequencer) bool {
@@ -1121,7 +1208,7 @@ func (s *TransactionStreamer) ExecuteNextMsg(ctx context.Context, exec execution
     if pos >= msgCount {
         return false
     }
-    msg, err := s.GetMessage(pos)
+    msgAndBlockHash, err := s.getMessageWithMetadataAndBlockHash(pos)
     if err != nil {
         log.Error("feedOneMsg failed to readMessage", "err", err, "pos", pos)
         return false
@@ -1135,7 +1222,7 @@ func (s *TransactionStreamer) ExecuteNextMsg(ctx context.Context, exec execution
         }
         msgForPrefetch = msg
     }
-    msgResult, err := s.exec.DigestMessage(pos, msg, msgForPrefetch)
+    msgResult, err := s.exec.DigestMessage(pos, &msgAndBlockHash.MessageWithMeta, msgForPrefetch)
     if err != nil {
         logger := log.Warn
         if prevMessageCount < msgCount {
@@ -1145,11 +1232,13 @@ func (s *TransactionStreamer) ExecuteNextMsg(ctx context.Context, exec execution
         return false
     }
 
-    msgWithBlockHash := broadcaster.MessageWithMetadataAndBlockHash{
-        Message:   *msg,
-        BlockHash: &msgResult.BlockHash,
+    s.checkResult(msgResult, msgAndBlockHash.BlockHash)
+
+    msgWithBlockHash := arbostypes.MessageWithMetadataAndBlockHash{
+        MessageWithMeta: msgAndBlockHash.MessageWithMeta,
+        BlockHash:       &msgResult.BlockHash,
     }
-    s.broadcastMessages([]broadcaster.MessageWithMetadataAndBlockHash{msgWithBlockHash}, pos)
+    s.broadcastMessages([]arbostypes.MessageWithMetadataAndBlockHash{msgWithBlockHash}, pos)
 
     return pos+1 < msgCount
 }
diff --git a/arbos/arbostypes/messagewithmeta.go b/arbos/arbostypes/messagewithmeta.go
index a3d4f5e3c3..79b7c4f9d2 100644
--- a/arbos/arbostypes/messagewithmeta.go
+++ b/arbos/arbostypes/messagewithmeta.go
@@ -18,6 +18,11 @@ type MessageWithMetadata struct {
     DelayedMessagesRead uint64 `json:"delayedMessagesRead"`
 }
 
+type MessageWithMetadataAndBlockHash struct {
+    MessageWithMeta MessageWithMetadata
+    BlockHash       *common.Hash
+}
+
 var EmptyTestMessageWithMetadata = MessageWithMetadata{
     Message: &EmptyTestIncomingMessage,
 }
diff --git a/arbos/programs/data_pricer.go b/arbos/programs/data_pricer.go
index b0184d7dc7..ed7c98556d 100644
--- a/arbos/programs/data_pricer.go
+++ b/arbos/programs/data_pricer.go
@@ -27,12 +27,14 @@ const (
     inertiaOffset
 )
 
+const ArbitrumStartTime = 1421388000 // the day it all began
+
 const initialDemand = 0                                      // no demand
 const InitialHourlyBytes = 1 * (1 << 40) / (365 * 24)        // 1Tb total footprint
 const initialBytesPerSecond = InitialHourlyBytes / (60 * 60) // refill each second
-const initialLastUpdateTime = 1421388000                     // the day it all began
-const initialMinPrice = 82928201                             // 5Mb = $1
-const initialInertia = 21360419                              // expensive at 1Tb
+const initialLastUpdateTime = ArbitrumStartTime
+const initialMinPrice = 82928201 // 5Mb = $1
+const initialInertia = 21360419  // expensive at 1Tb
 
 func initDataPricer(sto *storage.Storage) {
     demand := sto.OpenStorageBackedUint32(demandOffset)
diff --git a/arbos/programs/programs.go b/arbos/programs/programs.go
index d3113ae98d..6f73e16b85 100644
--- a/arbos/programs/programs.go
+++ b/arbos/programs/programs.go
@@ -527,12 +527,12 @@ func (status userStatus) toResult(data []byte, debug bool) ([]byte, string, erro
 
 // Hours since Arbitrum began, rounded down.
 func hoursSinceArbitrum(time uint64) uint24 {
-    return uint24((time - lastUpdateTimeOffset) / 3600)
+    return am.SaturatingUUCast[uint24]((am.SaturatingUSub(time, ArbitrumStartTime)) / 3600)
 }
 
 // Computes program age in seconds from the hours passed since Arbitrum began.
 func hoursToAge(time uint64, hours uint24) uint64 {
     seconds := am.SaturatingUMul(uint64(hours), 3600)
-    activatedAt := am.SaturatingUAdd(lastUpdateTimeOffset, seconds)
+    activatedAt := am.SaturatingUAdd(ArbitrumStartTime, seconds)
     return am.SaturatingUSub(time, activatedAt)
 }
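The hoursSinceArbitrum change fixes two distinct issues: the old code subtracted lastUpdateTimeOffset, which is a storage slot index rather than a timestamp, and the plain uint64 subtraction wraps around for any timestamp before the start time. A hypothetical illustration of the wraparound that the saturating version avoids:

    // With a timestamp one hour before ArbitrumStartTime:
    earlier := uint64(ArbitrumStartTime - 3600)
    _ = (earlier - ArbitrumStartTime) / 3600                  // wraps to roughly 5.1e15 "hours"
    _ = am.SaturatingUSub(earlier, ArbitrumStartTime) / 3600  // clamps to 0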
// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE
+
+//go:build !wasm
+// +build !wasm
+
 package programs
 
 // This file exists because cgo isn't allowed in tests
diff --git a/broadcaster/broadcaster.go b/broadcaster/broadcaster.go
index ac5c6c39da..ba95f2d8af 100644
--- a/broadcaster/broadcaster.go
+++ b/broadcaster/broadcaster.go
@@ -22,11 +22,6 @@ import (
     "github.com/offchainlabs/nitro/wsbroadcastserver"
 )
 
-type MessageWithMetadataAndBlockHash struct {
-    Message   arbostypes.MessageWithMetadata
-    BlockHash *common.Hash
-}
-
 type Broadcaster struct {
     server  *wsbroadcastserver.WSBroadcastServer
     backlog backlog.Backlog
@@ -98,7 +93,7 @@ func (b *Broadcaster) BroadcastSingleFeedMessage(bfm *m.BroadcastFeedMessage) {
 }
 
 func (b *Broadcaster) BroadcastMessages(
-    messagesWithBlockHash []MessageWithMetadataAndBlockHash,
+    messagesWithBlockHash []arbostypes.MessageWithMetadataAndBlockHash,
     seq arbutil.MessageIndex,
 ) (err error) {
     defer func() {
@@ -109,7 +104,7 @@ func (b *Broadcaster) BroadcastMessages(
     }()
     var feedMessages []*m.BroadcastFeedMessage
     for i, msg := range messagesWithBlockHash {
-        bfm, err := b.NewBroadcastFeedMessage(msg.Message, seq+arbutil.MessageIndex(i), msg.BlockHash)
+        bfm, err := b.NewBroadcastFeedMessage(msg.MessageWithMeta, seq+arbutil.MessageIndex(i), msg.BlockHash)
         if err != nil {
             return err
         }
diff --git a/cmd/staterecovery/staterecovery.go b/cmd/staterecovery/staterecovery.go
index 6390826a91..58ad06ad14 100644
--- a/cmd/staterecovery/staterecovery.go
+++ b/cmd/staterecovery/staterecovery.go
@@ -31,7 +31,7 @@ func RecreateMissingStates(chainDb ethdb.Database, bc *core.BlockChain, cacheCon
         return fmt.Errorf("start block parent is missing, parent block number: %d", current-1)
     }
     hashConfig := *hashdb.Defaults
-    hashConfig.CleanCacheSize = cacheConfig.TrieCleanLimit
+    hashConfig.CleanCacheSize = cacheConfig.TrieCleanLimit * 1024 * 1024 // TrieCleanLimit is in MB; CleanCacheSize is in bytes
     trieConfig := &trie.Config{
         Preimages: false,
         HashDB:    &hashConfig,
diff --git a/execution/gethexec/executionengine.go b/execution/gethexec/executionengine.go
index 38569f44ab..96dca6c63e 100644
--- a/execution/gethexec/executionengine.go
+++ b/execution/gethexec/executionengine.go
@@ -116,7 +116,7 @@ func (s *ExecutionEngine) GetBatchFetcher() execution.BatchFetcher {
     return s.consensus
 }
 
-func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata, oldMessages []*arbostypes.MessageWithMetadata) ([]*execution.MessageResult, error) {
+func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadataAndBlockHash, oldMessages []*arbostypes.MessageWithMetadata) ([]*execution.MessageResult, error) {
     if count == 0 {
         return nil, errors.New("cannot reorg out genesis")
     }
@@ -149,9 +149,9 @@ func (s *ExecutionEngine) Reorg(count arbutil.MessageIndex, newMessages []arbost
     for i := range newMessages {
         var msgForPrefetch *arbostypes.MessageWithMetadata
         if i < len(newMessages)-1 {
-            msgForPrefetch = &newMessages[i]
+            msgForPrefetch = &newMessages[i].MessageWithMeta
         }
-        msgResult, err := s.digestMessageWithBlockMutex(count+arbutil.MessageIndex(i), &newMessages[i], msgForPrefetch)
+        msgResult, err := s.digestMessageWithBlockMutex(count+arbutil.MessageIndex(i), &newMessages[i].MessageWithMeta, msgForPrefetch)
         if err != nil {
             return nil, err
         }
@@ -197,7 +197,7 @@ func (s *ExecutionEngine) NextDelayedMessageNumber() (uint64, error) {
     return currentHeader.Nonce.Uint64(), nil
 }
 
-func messageFromTxes(header *arbostypes.L1IncomingMessageHeader, txes types.Transactions, txErrors []error) (*arbostypes.L1IncomingMessage, error) {
+func MessageFromTxes(header *arbostypes.L1IncomingMessageHeader, txes types.Transactions, txErrors []error) (*arbostypes.L1IncomingMessage, error) {
     var l2Message []byte
     if len(txes) == 1 && txErrors[0] == nil {
         txBytes, err := txes[0].MarshalBinary()
@@ -368,7 +368,7 @@ func (s *ExecutionEngine) sequenceTransactionsWithBlockMutex(header *arbostypes.
         return nil, nil
     }
 
-    msg, err := messageFromTxes(header, txes, hooks.TxErrors)
+    msg, err := MessageFromTxes(header, txes, hooks.TxErrors)
     if err != nil {
         return nil, err
     }
diff --git a/execution/gethexec/node.go b/execution/gethexec/node.go
index ae76b88530..458d6601c5 100644
--- a/execution/gethexec/node.go
+++ b/execution/gethexec/node.go
@@ -346,7 +346,7 @@ func (n *ExecutionNode) StopAndWait() {
 func (n *ExecutionNode) DigestMessage(num arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata, msgForPrefetch *arbostypes.MessageWithMetadata) (*execution.MessageResult, error) {
     return n.ExecEngine.DigestMessage(num, msg, msgForPrefetch)
 }
-func (n *ExecutionNode) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata, oldMessages []*arbostypes.MessageWithMetadata) ([]*execution.MessageResult, error) {
+func (n *ExecutionNode) Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadataAndBlockHash, oldMessages []*arbostypes.MessageWithMetadata) ([]*execution.MessageResult, error) {
     return n.ExecEngine.Reorg(count, newMessages, oldMessages)
 }
 func (n *ExecutionNode) HeadMessageNumber() (arbutil.MessageIndex, error) {
diff --git a/execution/interface.go b/execution/interface.go
index d2a5b58fe5..66aefe9a5e 100644
--- a/execution/interface.go
+++ b/execution/interface.go
@@ -31,7 +31,7 @@ var ErrSequencerInsertLockTaken = errors.New("insert lock taken")
 
 // always needed
 type ExecutionClient interface {
     DigestMessage(num arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata, msgForPrefetch *arbostypes.MessageWithMetadata) (*MessageResult, error)
-    Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadata, oldMessages []*arbostypes.MessageWithMetadata) ([]*MessageResult, error)
+    Reorg(count arbutil.MessageIndex, newMessages []arbostypes.MessageWithMetadataAndBlockHash, oldMessages []*arbostypes.MessageWithMetadata) ([]*MessageResult, error)
     HeadMessageNumber() (arbutil.MessageIndex, error)
     HeadMessageNumberSync(t *testing.T) (arbutil.MessageIndex, error)
     ResultAtPos(pos arbutil.MessageIndex) (*MessageResult, error)
diff --git a/system_tests/seq_coordinator_test.go b/system_tests/seq_coordinator_test.go
index 886a0528c7..43d55f40c9 100644
--- a/system_tests/seq_coordinator_test.go
+++ b/system_tests/seq_coordinator_test.go
@@ -8,12 +8,14 @@ import (
     "errors"
     "fmt"
     "math/big"
+    "net"
     "testing"
     "time"
 
     "github.com/go-redis/redis/v8"
 
     "github.com/ethereum/go-ethereum/common"
+    "github.com/ethereum/go-ethereum/log"
 
     "github.com/offchainlabs/nitro/arbnode"
     "github.com/offchainlabs/nitro/arbos/arbostypes"
@@ -21,6 +23,7 @@ import (
     "github.com/offchainlabs/nitro/execution"
     "github.com/offchainlabs/nitro/execution/gethexec"
     "github.com/offchainlabs/nitro/util/redisutil"
+    "github.com/offchainlabs/nitro/util/testhelpers"
 )
 
 func initRedisForTest(t *testing.T, ctx context.Context, redisUrl string, nodeNames []string) {
@@ -270,6 +273,8 @@ func TestRedisSeqCoordinatorPriorities(t *testing.T) {
 }
 
 func testCoordinatorMessageSync(t *testing.T, successCase bool) {
+    logHandler := testhelpers.InitTestLog(t, log.LvlTrace)
+
     ctx, cancel := context.WithCancel(context.Background())
     defer cancel()
@@ -304,16 +309,25 @@ func testCoordinatorMessageSync(t *testing.T, successCase bool) {
 
     nodeConfigDup := *builder.nodeConfig
     builder.nodeConfig = &nodeConfigDup
-
+    builder.nodeConfig.Feed.Output = *newBroadcasterConfigTest()
     builder.nodeConfig.SeqCoordinator.MyUrl = nodeNames[1]
     if !successCase {
         builder.nodeConfig.SeqCoordinator.Signer.ECDSA.AcceptSequencer = false
         builder.nodeConfig.SeqCoordinator.Signer.ECDSA.AllowedAddresses = []string{builder.L2Info.GetAddress("User2").Hex()}
     }
-
     testClientB, cleanupB := builder.Build2ndNode(t, &SecondNodeParams{nodeConfig: builder.nodeConfig})
     defer cleanupB()
 
+    // Build nodeBOutputFeedReader.
+    // nodeB doesn't sequence transactions, but adds messages related to them to its output feed.
+    // nodeBOutputFeedReader reads those messages from this feed and processes them.
+    // nodeBOutputFeedReader doesn't read messages from L1 since none of the nodes posts to L1.
+    nodeBPort := testClientB.ConsensusNode.BroadcastServer.ListenerAddr().(*net.TCPAddr).Port
+    nodeConfigNodeBOutputFeedReader := arbnode.ConfigDefaultL1NonSequencerTest()
+    nodeConfigNodeBOutputFeedReader.Feed.Input = *newBroadcastClientConfigTest(nodeBPort)
+    testClientNodeBOutputFeedReader, cleanupNodeBOutputFeedReader := builder.Build2ndNode(t, &SecondNodeParams{nodeConfig: nodeConfigNodeBOutputFeedReader})
+    defer cleanupNodeBOutputFeedReader()
+
     tx := builder.L2Info.PrepareTx("Owner", "User2", builder.L2Info.TransferGas, big.NewInt(1e12), nil)
 
     err = builder.L2.Client.SendTransaction(ctx, tx)
@@ -330,6 +344,19 @@ func testCoordinatorMessageSync(t *testing.T, successCase bool) {
         if l2balance.Cmp(big.NewInt(1e12)) != 0 {
             t.Fatal("Unexpected balance:", l2balance)
         }
+
+        // Check that nodeBOutputFeedReader also processed the transaction.
+        _, err = WaitForTx(ctx, testClientNodeBOutputFeedReader.Client, tx.Hash(), time.Second*5)
+        Require(t, err)
+        l2balance, err = testClientNodeBOutputFeedReader.Client.BalanceAt(ctx, builder.L2Info.GetAddress("User2"), nil)
+        Require(t, err)
+        if l2balance.Cmp(big.NewInt(1e12)) != 0 {
+            t.Fatal("Unexpected balance:", l2balance)
+        }
+
+        if logHandler.WasLogged(arbnode.BlockHashMismatchLogMsg) {
+            t.Fatal("BlockHashMismatchLogMsg was logged unexpectedly")
+        }
     } else {
         _, err = WaitForTx(ctx, testClientB.Client, tx.Hash(), time.Second)
         if err == nil {
diff --git a/system_tests/seqfeed_test.go b/system_tests/seqfeed_test.go
index 749a91e3b1..ab30598b60 100644
--- a/system_tests/seqfeed_test.go
+++ b/system_tests/seqfeed_test.go
@@ -11,10 +11,19 @@ import (
     "testing"
     "time"
 
+    "github.com/ethereum/go-ethereum/common"
+    "github.com/ethereum/go-ethereum/core/types"
+    "github.com/ethereum/go-ethereum/log"
     "github.com/offchainlabs/nitro/arbnode"
+    "github.com/offchainlabs/nitro/arbos/arbostypes"
+    "github.com/offchainlabs/nitro/arbos/l1pricing"
     "github.com/offchainlabs/nitro/broadcastclient"
+    "github.com/offchainlabs/nitro/broadcaster/backlog"
+    "github.com/offchainlabs/nitro/broadcaster/message"
+    "github.com/offchainlabs/nitro/execution/gethexec"
     "github.com/offchainlabs/nitro/relay"
     "github.com/offchainlabs/nitro/util/signature"
+    "github.com/offchainlabs/nitro/util/testhelpers"
     "github.com/offchainlabs/nitro/wsbroadcastserver"
 )
 
@@ -38,7 +47,8 @@ func newBroadcastClientConfigTest(port int) *broadcastclient.Config {
 }
 
 func TestSequencerFeed(t *testing.T) {
-    t.Parallel()
+    logHandler := testhelpers.InitTestLog(t, log.LvlTrace)
+
     ctx, cancel := context.WithCancel(context.Background())
     defer cancel()
@@ -73,6 +83,10 @@ func TestSequencerFeed(t *testing.T) {
     if l2balance.Cmp(big.NewInt(1e12)) != 0 {
         t.Fatal("Unexpected balance:", l2balance)
     }
+
+    if logHandler.WasLogged(arbnode.BlockHashMismatchLogMsg) {
+        t.Fatal("BlockHashMismatchLogMsg was logged unexpectedly")
+    }
 }
 
 func TestRelayedSequencerFeed(t *testing.T) {
@@ -250,3 +264,101 @@ func TestLyingSequencer(t *testing.T) {
 func TestLyingSequencerLocalDAS(t *testing.T) {
     testLyingSequencer(t, "files")
 }
+
+func testBlockHashComparison(t *testing.T, blockHash *common.Hash, mustMismatch bool) {
+    logHandler := testhelpers.InitTestLog(t, log.LvlTrace)
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    backlogConfigFetcher := func() *backlog.Config {
+        return &backlog.DefaultTestConfig
+    }
+    bklg := backlog.NewBacklog(backlogConfigFetcher)
+
+    wsBroadcastServer := wsbroadcastserver.NewWSBroadcastServer(
+        newBroadcasterConfigTest,
+        bklg,
+        412346,
+        nil,
+    )
+    err := wsBroadcastServer.Initialize()
+    if err != nil {
+        t.Fatal("error initializing wsBroadcastServer:", err)
+    }
+    err = wsBroadcastServer.Start(ctx)
+    if err != nil {
+        t.Fatal("error starting wsBroadcastServer:", err)
+    }
+    defer wsBroadcastServer.StopAndWait()
+
+    port := wsBroadcastServer.ListenerAddr().(*net.TCPAddr).Port
+
+    builder := NewNodeBuilder(ctx).DefaultConfig(t, true)
+    builder.nodeConfig.Feed.Input = *newBroadcastClientConfigTest(port)
+    cleanup := builder.Build(t)
+    defer cleanup()
+    testClient := builder.L2
+
+    userAccount := "User2"
+    builder.L2Info.GenerateAccount(userAccount)
+    tx := builder.L2Info.PrepareTx("Owner", userAccount, builder.L2Info.TransferGas, big.NewInt(1e12), nil)
+    l1IncomingMsgHeader := arbostypes.L1IncomingMessageHeader{
+        Kind:        arbostypes.L1MessageType_L2Message,
+        Poster:      l1pricing.BatchPosterAddress,
+        BlockNumber: 29,
+        Timestamp:   1715295980,
+        RequestId:   nil,
+        L1BaseFee:   nil,
+    }
+    l1IncomingMsg, err := gethexec.MessageFromTxes(
+        &l1IncomingMsgHeader,
+        types.Transactions{tx},
+        []error{nil},
+    )
+    Require(t, err)
+
+    broadcastMessage := message.BroadcastMessage{
+        Version: 1,
+        Messages: []*message.BroadcastFeedMessage{
+            {
+                SequenceNumber: 1,
+                Message: arbostypes.MessageWithMetadata{
+                    Message:             l1IncomingMsg,
+                    DelayedMessagesRead: 1,
+                },
+                BlockHash: blockHash,
+            },
+        },
+    }
+    wsBroadcastServer.Broadcast(&broadcastMessage)
+
+    // Even if the block hash mismatches, the transaction should still be processed.
+    _, err = WaitForTx(ctx, testClient.Client, tx.Hash(), time.Second*15)
+    if err != nil {
+        t.Fatal("error waiting for tx:", err)
+    }
+    l2balance, err := testClient.Client.BalanceAt(ctx, builder.L2Info.GetAddress(userAccount), nil)
+    if err != nil {
+        t.Fatal("error getting balance:", err)
+    }
+    if l2balance.Cmp(big.NewInt(1e12)) != 0 {
+        t.Fatal("Unexpected balance:", l2balance)
+    }
+
+    mismatched := logHandler.WasLogged(arbnode.BlockHashMismatchLogMsg)
+    if mustMismatch && !mismatched {
+        t.Fatal("Failed to log BlockHashMismatchLogMsg")
+    } else if !mustMismatch && mismatched {
+        t.Fatal("BlockHashMismatchLogMsg was logged unexpectedly")
+    }
+}
+
+func TestBlockHashFeedMismatch(t *testing.T) {
+    blockHash := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111")
+    testBlockHashComparison(t, &blockHash, true)
+}
+
+func TestBlockHashFeedNil(t *testing.T) {
+    testBlockHashComparison(t, nil, false)
+}
diff --git a/util/arbmath/math.go b/util/arbmath/math.go
index 1c11c6ad58..d7a0d1f523 100644
--- a/util/arbmath/math.go
+++ b/util/arbmath/math.go
@@ -74,14 +74,6 @@ func MaxInt[T Number](values ...T) T {
     return max
 }
 
-// AbsValue the absolute value of a number
-func AbsValue[T Number](value T) T {
-    if value < 0 {
-        return -value // never happens for unsigned types
-    }
-    return value
-}
-
 // Checks if two ints are sufficiently close to one another
 func Within[T Unsigned](a, b, bound T) bool {
     min := MinInt(a, b)
@@ -267,14 +259,22 @@ func BigFloatMulByUint(multiplicand *big.Float, multiplier uint64) *big.Float {
     return new(big.Float).Mul(multiplicand, UintToBigFloat(multiplier))
 }
 
+func MaxSignedValue[T Signed]() T {
+    return T((uint64(1) << (8*unsafe.Sizeof(T(0)) - 1)) - 1)
+}
+
+func MinSignedValue[T Signed]() T {
+    return T(uint64(1) << ((8 * unsafe.Sizeof(T(0))) - 1))
+}
+
 // SaturatingAdd add two integers without overflow
 func SaturatingAdd[T Signed](a, b T) T {
     sum := a + b
     if b > 0 && sum < a {
-        sum = ^T(0) >> 1
+        sum = MaxSignedValue[T]()
     }
     if b < 0 && sum > a {
-        sum = (^T(0) >> 1) + 1
+        sum = MinSignedValue[T]()
     }
     return sum
 }
@@ -290,7 +290,11 @@ func SaturatingUAdd[T Unsigned](a, b T) T {
 
 // SaturatingSub subtract an int64 from another without overflow
 func SaturatingSub(minuend, subtrahend int64) int64 {
-    return SaturatingAdd(minuend, -subtrahend)
+    if subtrahend == math.MinInt64 {
+        // The absolute value of MinInt64 is one greater than MaxInt64
+        return SaturatingAdd(SaturatingAdd(minuend, math.MaxInt64), 1)
+    }
+    return SaturatingAdd(minuend, SaturatingNeg(subtrahend))
 }
 
 // SaturatingUSub subtract an integer from another without underflow
@@ -315,9 +319,9 @@ func SaturatingMul[T Signed](a, b T) T {
     product := a * b
     if b != 0 && product/b != a {
         if (a > 0 && b > 0) || (a < 0 && b < 0) {
-            product = ^T(0) >> 1
+            product = MaxSignedValue[T]()
         } else {
-            product = (^T(0) >> 1) + 1
+            product = MinSignedValue[T]()
         }
     }
     return product
@@ -367,8 +371,8 @@ func SaturatingCastToUint(value *big.Int) uint64 {
 
 // Negates an int without underflow
 func SaturatingNeg[T Signed](value T) T {
-    if value == ^T(0) {
-        return (^T(0)) >> 1
+    if value < 0 && value == MinSignedValue[T]() {
+        return MaxSignedValue[T]()
     }
     return -value
 }
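The MinInt64 special case in SaturatingSub exists because negation itself can overflow: in two's complement, -math.MinInt64 wraps back to math.MinInt64, so the old SaturatingAdd(minuend, -subtrahend) turned "subtract MinInt64" into "add MinInt64". A quick check of the fixed edge case (also covered by the new TestSaturatingSub and the fuzz targets below):

    // Old behavior: SaturatingSub(0, math.MinInt64) returned math.MinInt64.
    // Fixed behavior: it clamps to math.MaxInt64, as subtracting a huge
    // negative number should.
    fmt.Println(SaturatingSub(0, math.MinInt64) == math.MaxInt64) // true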
diff --git a/util/arbmath/math_fuzz_test.go b/util/arbmath/math_fuzz_test.go
new file mode 100644
index 0000000000..591d699de0
--- /dev/null
+++ b/util/arbmath/math_fuzz_test.go
@@ -0,0 +1,112 @@
+// Copyright 2024, Offchain Labs, Inc.
+// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE
+
+package arbmath
+
+import (
+    "math/big"
+    "testing"
+)
+
+func toBig[T Signed](a T) *big.Int {
+    return big.NewInt(int64(a))
+}
+
+func saturatingBigToInt[T Signed](a *big.Int) T {
+    // MinSignedValue and MaxSignedValue are already separately tested
+    if a.Cmp(toBig(MaxSignedValue[T]())) > 0 {
+        return MaxSignedValue[T]()
+    }
+    if a.Cmp(toBig(MinSignedValue[T]())) < 0 {
+        return MinSignedValue[T]()
+    }
+    return T(a.Int64())
+}
+
+func fuzzSaturatingAdd[T Signed](f *testing.F) {
+    f.Fuzz(func(t *testing.T, a, b T) {
+        got := SaturatingAdd(a, b)
+        expected := saturatingBigToInt[T](new(big.Int).Add(toBig(a), toBig(b)))
+        if got != expected {
+            t.Errorf("SaturatingAdd(%v, %v) = %v, expected %v", a, b, got, expected)
+        }
+    })
+}
+
+func fuzzSaturatingMul[T Signed](f *testing.F) {
+    f.Fuzz(func(t *testing.T, a, b T) {
+        got := SaturatingMul(a, b)
+        expected := saturatingBigToInt[T](new(big.Int).Mul(toBig(a), toBig(b)))
+        if got != expected {
+            t.Errorf("SaturatingMul(%v, %v) = %v, expected %v", a, b, got, expected)
+        }
+    })
+}
+
+func fuzzSaturatingNeg[T Signed](f *testing.F) {
+    f.Fuzz(func(t *testing.T, a T) {
+        got := SaturatingNeg(a)
+        expected := saturatingBigToInt[T](new(big.Int).Neg(toBig(a)))
+        if got != expected {
+            t.Errorf("SaturatingNeg(%v) = %v, expected %v", a, got, expected)
+        }
+    })
+}
+
+func FuzzSaturatingAddInt8(f *testing.F) {
+    fuzzSaturatingAdd[int8](f)
+}
+
+func FuzzSaturatingAddInt16(f *testing.F) {
+    fuzzSaturatingAdd[int16](f)
+}
+
+func FuzzSaturatingAddInt32(f *testing.F) {
+    fuzzSaturatingAdd[int32](f)
+}
+
+func FuzzSaturatingAddInt64(f *testing.F) {
+    fuzzSaturatingAdd[int64](f)
+}
+
+func FuzzSaturatingSub(f *testing.F) {
+    f.Fuzz(func(t *testing.T, a, b int64) {
+        got := SaturatingSub(a, b)
+        expected := saturatingBigToInt[int64](new(big.Int).Sub(toBig(a), toBig(b)))
+        if got != expected {
+            t.Errorf("SaturatingSub(%v, %v) = %v, expected %v", a, b, got, expected)
+        }
+    })
+}
+
+func FuzzSaturatingMulInt8(f *testing.F) {
+    fuzzSaturatingMul[int8](f)
+}
+
+func FuzzSaturatingMulInt16(f *testing.F) {
+    fuzzSaturatingMul[int16](f)
+}
+
+func FuzzSaturatingMulInt32(f *testing.F) {
+    fuzzSaturatingMul[int32](f)
+}
+
+func FuzzSaturatingMulInt64(f *testing.F) {
+    fuzzSaturatingMul[int64](f)
+}
+
+func FuzzSaturatingNegInt8(f *testing.F) {
+    fuzzSaturatingNeg[int8](f)
+}
+
+func FuzzSaturatingNegInt16(f *testing.F) {
+    fuzzSaturatingNeg[int16](f)
+}
+
+func FuzzSaturatingNegInt32(f *testing.F) {
+    fuzzSaturatingNeg[int32](f)
+}
+
+func FuzzSaturatingNegInt64(f *testing.F) {
+    fuzzSaturatingNeg[int64](f)
+}
diff --git a/util/arbmath/math_test.go b/util/arbmath/math_test.go
index 2e2f14795a..1be60dc58b 100644
--- a/util/arbmath/math_test.go
+++ b/util/arbmath/math_test.go
@@ -5,6 +5,7 @@ package arbmath
 
 import (
     "bytes"
+    "fmt"
     "math"
     "math/rand"
     "testing"
@@ -120,6 +121,110 @@ func TestSlices(t *testing.T) {
     assert_eq(SliceWithRunoff(data, 7, 8), []uint8{})
 }
 
+func testMinMaxSignedValues[T Signed](t *testing.T, min T, max T) {
+    gotMin := MinSignedValue[T]()
+    if gotMin != min {
+        Fail(t, "expected min", min, "but got", gotMin)
+    }
+    gotMax := MaxSignedValue[T]()
+    if gotMax != max {
+        Fail(t, "expected max", max, "but got", gotMax)
+    }
+}
+
+func TestMinMaxSignedValues(t *testing.T) {
+    testMinMaxSignedValues[int8](t, math.MinInt8, math.MaxInt8)
+    testMinMaxSignedValues[int16](t, math.MinInt16, math.MaxInt16)
+    testMinMaxSignedValues[int32](t, math.MinInt32, math.MaxInt32)
+    testMinMaxSignedValues[int64](t, math.MinInt64, math.MaxInt64)
+}
+
+func TestSaturatingAdd(t *testing.T) {
+    tests := []struct {
+        a, b, expected int64
+    }{
+        {2, 3, 5},
+        {-1, -2, -3},
+        {math.MaxInt64, 1, math.MaxInt64},
+        {math.MaxInt64, math.MaxInt64, math.MaxInt64},
+        {math.MinInt64, -1, math.MinInt64},
+        {math.MinInt64, math.MinInt64, math.MinInt64},
+    }
+
+    for _, tc := range tests {
+        t.Run(fmt.Sprintf("%v + %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
+            sum := SaturatingAdd(int64(tc.a), int64(tc.b))
+            if sum != tc.expected {
+                t.Errorf("SaturatingAdd(%v, %v) = %v; want %v", tc.a, tc.b, sum, tc.expected)
+            }
+        })
+    }
+}
+
+func TestSaturatingSub(t *testing.T) {
+    tests := []struct {
+        a, b, expected int64
+    }{
+        {5, 3, 2},
+        {-3, -2, -1},
+        {math.MinInt64, 1, math.MinInt64},
+        {math.MinInt64, -1, math.MinInt64 + 1},
+        {math.MinInt64, math.MinInt64, 0},
+        {0, math.MinInt64, math.MaxInt64},
+    }
+
+    for _, tc := range tests {
+        t.Run(fmt.Sprintf("%v - %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
+            difference := SaturatingSub(int64(tc.a), int64(tc.b))
+            if difference != tc.expected {
+                t.Errorf("SaturatingSub(%v, %v) = %v; want %v", tc.a, tc.b, difference, tc.expected)
+            }
+        })
+    }
+}
+
+func TestSaturatingMul(t *testing.T) {
+    tests := []struct {
+        a, b, expected int64
+    }{
+        {5, 3, 15},
+        {-3, -2, 6},
+        {math.MaxInt64, 2, math.MaxInt64},
+        {math.MinInt64, 2, math.MinInt64},
+    }
+
+    for _, tc := range tests {
+        t.Run(fmt.Sprintf("%v * %v = %v", tc.a, tc.b, tc.expected), func(t *testing.T) {
+            product := SaturatingMul(int64(tc.a), int64(tc.b))
+            if product != tc.expected {
+                t.Errorf("SaturatingMul(%v, %v) = %v; want %v", tc.a, tc.b, product, tc.expected)
+            }
+        })
+    }
+}
+
+func TestSaturatingNeg(t *testing.T) {
+    tests := []struct {
+        value    int64
+        expected int64
+    }{
+        {0, 0},
+        {5, -5},
+        {-5, 5},
+        {math.MinInt64, math.MaxInt64},
+        {math.MaxInt64, math.MinInt64 + 1},
+    }
+
+    for _, tc := range tests {
+        t.Run(fmt.Sprintf("-%v = %v", tc.value, tc.expected), func(t *testing.T) {
+            result := SaturatingNeg(tc.value)
+            if result != tc.expected {
+                t.Errorf("SaturatingNeg(%v) = %v; want %v", tc.value, result, tc.expected)
+            }
+        })
+    }
+}
+
 func Fail(t *testing.T, printables ...interface{}) {
     t.Helper()
     testhelpers.FailImpl(t, printables...)
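These fuzz targets also run as plain unit tests over their seed corpus during an ordinary go test. To actively fuzz one of them (the Go tool accepts a single target per -fuzz run), an invocation along these lines should work:

    go test -run=^$ -fuzz=FuzzSaturatingSub -fuzztime=30s ./util/arbmath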