diff --git a/arbnode/inbox_tracker.go b/arbnode/inbox_tracker.go index 72e4ba2887..51f74cbeb4 100644 --- a/arbnode/inbox_tracker.go +++ b/arbnode/inbox_tracker.go @@ -22,6 +22,7 @@ import ( "github.com/offchainlabs/nitro/arbstate" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/containers" ) @@ -240,7 +241,7 @@ func (t *InboxTracker) PopulateFeedBacklog(broadcastServer *broadcaster.Broadcas if err != nil { return fmt.Errorf("error getting tx streamer message count: %w", err) } - var feedMessages []*broadcaster.BroadcastFeedMessage + var feedMessages []*m.BroadcastFeedMessage for seqNum := startMessage; seqNum < messageCount; seqNum++ { message, err := t.txStreamer.GetMessage(seqNum) if err != nil { diff --git a/arbnode/transaction_streamer.go b/arbnode/transaction_streamer.go index 3cbad93c9a..2d6f7d589c 100644 --- a/arbnode/transaction_streamer.go +++ b/arbnode/transaction_streamer.go @@ -32,6 +32,7 @@ import ( "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/execution" "github.com/offchainlabs/nitro/staker" "github.com/offchainlabs/nitro/util/arbmath" @@ -426,7 +427,7 @@ func (s *TransactionStreamer) AddMessages(pos arbutil.MessageIndex, messagesAreC return s.AddMessagesAndEndBatch(pos, messagesAreConfirmed, messages, nil) } -func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*broadcaster.BroadcastFeedMessage) error { +func (s *TransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFeedMessage) error { if len(feedMessages) == 0 { return nil } diff --git a/broadcastclient/broadcastclient.go b/broadcastclient/broadcastclient.go index e94daa463c..f27fc28fa0 100644 --- a/broadcastclient/broadcastclient.go +++ b/broadcastclient/broadcastclient.go @@ -27,7 +27,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/offchainlabs/nitro/arbutil" - "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/util/contracts" "github.com/offchainlabs/nitro/util/signature" "github.com/offchainlabs/nitro/util/stopwaiter" @@ -117,7 +117,7 @@ var DefaultTestConfig = Config{ } type TransactionStreamerInterface interface { - AddBroadcastMessages(feedMessages []*broadcaster.BroadcastFeedMessage) error + AddBroadcastMessages(feedMessages []*m.BroadcastFeedMessage) error } type BroadcastClient struct { @@ -381,7 +381,7 @@ func (bc *BroadcastClient) startBackgroundReader(earlyFrameData io.Reader) { backoffDuration = bc.config().ReconnectInitialBackoff if msg != nil { - res := broadcaster.BroadcastMessage{} + res := m.BroadcastMessage{} err = json.Unmarshal(msg, &res) if err != nil { log.Error("error unmarshalling message", "msg", msg, "err", err) @@ -483,7 +483,7 @@ func (bc *BroadcastClient) StopAndWait() { } } -func (bc *BroadcastClient) isValidSignature(ctx context.Context, message *broadcaster.BroadcastFeedMessage) error { +func (bc *BroadcastClient) isValidSignature(ctx context.Context, message *m.BroadcastFeedMessage) error { if bc.config().Verify.Dangerous.AcceptMissing && bc.sigVerifier == nil { // Verifier disabled return nil diff --git a/broadcastclient/broadcastclient_test.go b/broadcastclient/broadcastclient_test.go index fa743d4229..b75bc44c11 100644 --- 
a/broadcastclient/broadcastclient_test.go +++ b/broadcastclient/broadcastclient_test.go @@ -23,6 +23,7 @@ import ( "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/util/contracts" "github.com/offchainlabs/nitro/util/signature" "github.com/offchainlabs/nitro/util/testhelpers" @@ -178,20 +179,20 @@ func TestInvalidSignature(t *testing.T) { } type dummyTransactionStreamer struct { - messageReceiver chan broadcaster.BroadcastFeedMessage + messageReceiver chan m.BroadcastFeedMessage chainId uint64 sequencerAddr *common.Address } func NewDummyTransactionStreamer(chainId uint64, sequencerAddr *common.Address) *dummyTransactionStreamer { return &dummyTransactionStreamer{ - messageReceiver: make(chan broadcaster.BroadcastFeedMessage), + messageReceiver: make(chan m.BroadcastFeedMessage), chainId: chainId, sequencerAddr: sequencerAddr, } } -func (ts *dummyTransactionStreamer) AddBroadcastMessages(feedMessages []*broadcaster.BroadcastFeedMessage) error { +func (ts *dummyTransactionStreamer) AddBroadcastMessages(feedMessages []*m.BroadcastFeedMessage) error { for _, feedMessage := range feedMessages { ts.messageReceiver <- *feedMessage } diff --git a/broadcastclients/broadcastclients.go b/broadcastclients/broadcastclients.go index 29f5e5192a..23c4bd7738 100644 --- a/broadcastclients/broadcastclients.go +++ b/broadcastclients/broadcastclients.go @@ -12,7 +12,7 @@ import ( "github.com/offchainlabs/nitro/arbutil" "github.com/offchainlabs/nitro/broadcastclient" - "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/util/contracts" "github.com/offchainlabs/nitro/util/stopwaiter" ) @@ -25,14 +25,14 @@ const PRIMARY_FEED_UPTIME = time.Minute * 10 type Router struct { stopwaiter.StopWaiter - messageChan chan broadcaster.BroadcastFeedMessage + messageChan chan m.BroadcastFeedMessage confirmedSequenceNumberChan chan arbutil.MessageIndex forwardTxStreamer broadcastclient.TransactionStreamerInterface forwardConfirmationChan chan arbutil.MessageIndex } -func (r *Router) AddBroadcastMessages(feedMessages []*broadcaster.BroadcastFeedMessage) error { +func (r *Router) AddBroadcastMessages(feedMessages []*m.BroadcastFeedMessage) error { for _, feedMessage := range feedMessages { r.messageChan <- *feedMessage } @@ -67,7 +67,7 @@ func NewBroadcastClients( } newStandardRouter := func() *Router { return &Router{ - messageChan: make(chan broadcaster.BroadcastFeedMessage, ROUTER_QUEUE_SIZE), + messageChan: make(chan m.BroadcastFeedMessage, ROUTER_QUEUE_SIZE), confirmedSequenceNumberChan: make(chan arbutil.MessageIndex, ROUTER_QUEUE_SIZE), forwardTxStreamer: txStreamer, forwardConfirmationChan: confirmedSequenceNumberListener, @@ -152,7 +152,7 @@ func (bcs *BroadcastClients) Start(ctx context.Context) { defer stopSecondaryFeedTimer.Stop() defer primaryFeedIsDownTimer.Stop() - msgHandler := func(msg broadcaster.BroadcastFeedMessage, router *Router) error { + msgHandler := func(msg m.BroadcastFeedMessage, router *Router) error { if _, ok := recentFeedItemsNew[msg.SequenceNumber]; ok { return nil } @@ -160,7 +160,7 @@ func (bcs *BroadcastClients) Start(ctx context.Context) { return nil } recentFeedItemsNew[msg.SequenceNumber] = time.Now() - if err := router.forwardTxStreamer.AddBroadcastMessages([]*broadcaster.BroadcastFeedMessage{&msg}); err != nil { + if err := 
router.forwardTxStreamer.AddBroadcastMessages([]*m.BroadcastFeedMessage{&msg}); err != nil {
 			return err
 		}
 		return nil
diff --git a/broadcaster/backlog/backlog.go b/broadcaster/backlog/backlog.go
new file mode 100644
index 0000000000..a1bd2c302f
--- /dev/null
+++ b/broadcaster/backlog/backlog.go
@@ -0,0 +1,406 @@
+package backlog
+
+import (
+	"errors"
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"github.com/ethereum/go-ethereum/log"
+	m "github.com/offchainlabs/nitro/broadcaster/message"
+	"github.com/offchainlabs/nitro/util/arbmath"
+	"github.com/offchainlabs/nitro/util/containers"
+)
+
+var (
+	errDropSegments       = errors.New("remove previous segments from backlog")
+	errSequenceNumberSeen = errors.New("sequence number already present in backlog")
+	errOutOfBounds        = errors.New("message not found in backlog")
+)
+
+// Backlog defines the interface for backlog.
+type Backlog interface {
+	Head() BacklogSegment
+	Append(*m.BroadcastMessage) error
+	Get(uint64, uint64) (*m.BroadcastMessage, error)
+	Count() uint64
+	Lookup(uint64) (BacklogSegment, error)
+}
+
+// backlog stores backlogSegments and provides the ability to read/write
+// messages.
+type backlog struct {
+	head          atomic.Pointer[backlogSegment]
+	tail          atomic.Pointer[backlogSegment]
+	lookupByIndex *containers.SyncMap[uint64, *backlogSegment]
+	config        ConfigFetcher
+	messageCount  atomic.Uint64
+}
+
+// NewBacklog creates a backlog.
+func NewBacklog(c ConfigFetcher) Backlog {
+	lookup := &containers.SyncMap[uint64, *backlogSegment]{}
+	return &backlog{
+		lookupByIndex: lookup,
+		config:        c,
+	}
+}
+
+// Head returns the head backlogSegment within the backlog.
+func (b *backlog) Head() BacklogSegment {
+	return b.head.Load()
+}
+
+// Append will add the given messages to the backlogSegment at the tail until
+// that segment reaches its limit. If messages remain to be added a new segment
+// will be created.
+func (b *backlog) Append(bm *m.BroadcastMessage) error {
+
+	if bm.ConfirmedSequenceNumberMessage != nil {
+		b.delete(uint64(bm.ConfirmedSequenceNumberMessage.SequenceNumber))
+	}
+
+	for _, msg := range bm.Messages {
+		segment := b.tail.Load()
+		if segment == nil {
+			segment = newBacklogSegment()
+			b.head.Store(segment)
+			b.tail.Store(segment)
+		}
+
+		prevMsgIdx := segment.End()
+		if segment.count() >= b.config().SegmentLimit {
+			nextSegment := newBacklogSegment()
+			segment.nextSegment.Store(nextSegment)
+			prevMsgIdx = segment.End()
+			nextSegment.previousSegment.Store(segment)
+			segment = nextSegment
+			b.tail.Store(segment)
+		}
+
+		err := segment.append(prevMsgIdx, msg)
+		if errors.Is(err, errDropSegments) {
+			head := b.head.Load()
+			b.removeFromLookup(head.Start(), uint64(msg.SequenceNumber))
+			b.head.Store(segment)
+			b.tail.Store(segment)
+			b.messageCount.Store(0)
+			log.Warn(err.Error())
+		} else if errors.Is(err, errSequenceNumberSeen) {
+			log.Info("ignoring message sequence number, already in backlog", "message sequence number", msg.SequenceNumber)
+			continue
+		} else if err != nil {
+			return err
+		}
+		b.lookupByIndex.Store(uint64(msg.SequenceNumber), segment)
+		b.messageCount.Add(1)
+	}
+
+	return nil
+}
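To make the segment bookkeeping concrete, here is a minimal usage sketch (not part of the diff; it leans on the `CreateDummyBroadcastMessage` helper this PR adds in `broadcaster/message/message_test_utils.go`). With the test config's `SegmentLimit` of 3, seven appended messages span three segments:

```go
package main

import (
	"fmt"

	"github.com/offchainlabs/nitro/arbutil"
	"github.com/offchainlabs/nitro/broadcaster/backlog"
	m "github.com/offchainlabs/nitro/broadcaster/message"
)

func main() {
	// DefaultTestConfig has SegmentLimit: 3, so 40..46 fills three segments.
	b := backlog.NewBacklog(func() *backlog.Config { return &backlog.DefaultTestConfig })

	bm := m.CreateDummyBroadcastMessage([]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46})
	if err := b.Append(bm); err != nil {
		panic(err)
	}

	fmt.Println(b.Count())        // 7
	fmt.Println(b.Head().Start()) // 40
}
```

+
+// Get reads messages from the given start to end MessageIndex.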
+func (b *backlog) Get(start, end uint64) (*m.BroadcastMessage, error) {
+	head := b.head.Load()
+	tail := b.tail.Load()
+	if head == nil && tail == nil {
+		return nil, errOutOfBounds
+	}
+
+	if end > tail.End() {
+		return nil, errOutOfBounds
+	}
+
+	segment, err := b.Lookup(start)
+	head = b.head.Load()
+	if start < head.Start() {
+		// doing this check after the Lookup call ensures there is no race
+		// condition with a delete call
+		start = head.Start()
+		segment = head
+	} else if err != nil {
+		return nil, err
+	}
+
+	bm := &m.BroadcastMessage{Version: 1}
+	required := int(end-start) + 1
+	for {
+		segMsgs, err := segment.Get(arbmath.MaxInt(start, segment.Start()), arbmath.MinInt(end, segment.End()))
+		if err != nil {
+			return nil, err
+		}
+
+		bm.Messages = append(bm.Messages, segMsgs...)
+		segment = segment.Next()
+		if len(bm.Messages) == required {
+			break
+		} else if segment == nil {
+			return nil, errOutOfBounds
+		}
+	}
+	return bm, nil
+}
+
+// delete removes segments before the confirmed sequence number given. The
+// segment containing the confirmed sequence number will continue to store
+// previous messages but will register that messages up to the given number
+// have been deleted.
+func (b *backlog) delete(confirmed uint64) {
+	head := b.head.Load()
+	tail := b.tail.Load()
+	if head == nil && tail == nil {
+		return
+	}
+
+	if confirmed < head.Start() {
+		return
+	}
+
+	if confirmed > tail.End() {
+		log.Error("confirmed sequence number is past the end of stored messages", "confirmed sequence number", confirmed, "last stored sequence number", tail.End())
+		b.reset()
+		return
+	}
+
+	// find the segment containing the confirmed message
+	found, err := b.Lookup(confirmed)
+	if err != nil {
+		log.Error(fmt.Sprintf("%s: clearing backlog", err.Error()))
+		b.reset()
+		return
+	}
+	segment, ok := found.(*backlogSegment)
+	if !ok {
+		log.Error("error in backlogSegment type assertion: clearing backlog")
+		b.reset()
+		return
+	}
+
+	// delete messages from the segment with the confirmed message
+	newHead := segment
+	start := head.Start()
+	if segment.End() == confirmed {
+		found = segment.Next()
+		newHead, ok = found.(*backlogSegment)
+		if !ok {
+			log.Error("error in backlogSegment type assertion: clearing backlog")
+			b.reset()
+			return
+		}
+	} else {
+		err = segment.delete(confirmed)
+		if err != nil {
+			log.Error(fmt.Sprintf("%s: clearing backlog", err.Error()))
+			b.reset()
+			return
+		}
+	}
+
+	// tidy up lookup, count and head
+	b.removeFromLookup(start, confirmed)
+	count := b.Count() + start - confirmed - uint64(1)
+	b.messageCount.Store(count)
+	b.head.Store(newHead)
+}
+
+// removeFromLookup removes all entries from the head segment's start index to
+// the given confirmed index.
+func (b *backlog) removeFromLookup(start, end uint64) {
+	for i := start; i <= end; i++ {
+		b.lookupByIndex.Delete(i)
+	}
+}
+
+// Lookup attempts to find the backlogSegment storing the given message index.
+func (b *backlog) Lookup(i uint64) (BacklogSegment, error) {
+	segment, ok := b.lookupByIndex.Load(i)
+	if !ok {
+		return nil, fmt.Errorf("error finding backlog segment containing message with SequenceNumber %d", i)
+	}
+
+	return segment, nil
+}
+
+// Count returns the number of messages stored within the backlog.
+func (s *backlog) Count() uint64 {
+	return s.messageCount.Load()
+}
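Worked through with concrete numbers, the pruning arithmetic in delete (`count = Count() + start - confirmed - 1`) matches the `FirstMsgInSegment` case in the tests below: confirming 43 out of 40..46 leaves 7 + 40 - 43 - 1 = 3 messages. A sketch of the same flow through the public API (again relying on the dummy-message test helper; note that confirmations arrive via `Append`, which routes them to `delete`):

```go
package main

import (
	"fmt"

	"github.com/offchainlabs/nitro/arbutil"
	"github.com/offchainlabs/nitro/broadcaster/backlog"
	m "github.com/offchainlabs/nitro/broadcaster/message"
)

func main() {
	b := backlog.NewBacklog(func() *backlog.Config { return &backlog.DefaultTestConfig })

	// Fill the backlog with sequence numbers 40..46.
	if err := b.Append(m.CreateDummyBroadcastMessage([]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46})); err != nil {
		panic(err)
	}

	// A confirmation for 43 prunes everything up to and including 43.
	if err := b.Append(&m.BroadcastMessage{
		ConfirmedSequenceNumberMessage: &m.ConfirmedSequenceNumberMessage{SequenceNumber: 43},
	}); err != nil {
		panic(err)
	}

	fmt.Println(b.Count()) // 3 (44, 45 and 46 remain)

	bm, err := b.Get(44, 46)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(bm.Messages)) // 3
}
```

+
+// reset removes all segments from the backlog.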
+func (b *backlog) reset() {
+	b.head = atomic.Pointer[backlogSegment]{}
+	b.tail = atomic.Pointer[backlogSegment]{}
+	b.lookupByIndex = &containers.SyncMap[uint64, *backlogSegment]{}
+	b.messageCount.Store(0)
+}
+
+// BacklogSegment defines the interface for backlogSegment.
+type BacklogSegment interface {
+	Start() uint64
+	End() uint64
+	Next() BacklogSegment
+	Contains(uint64) bool
+	Messages() []*m.BroadcastFeedMessage
+	Get(uint64, uint64) ([]*m.BroadcastFeedMessage, error)
+}
+
+// backlogSegment stores messages up to a limit defined by the backlog. It also
+// points to the next backlogSegment in the list.
+type backlogSegment struct {
+	messagesLock    sync.RWMutex
+	messages        []*m.BroadcastFeedMessage
+	nextSegment     atomic.Pointer[backlogSegment]
+	previousSegment atomic.Pointer[backlogSegment]
+}
+
+// newBacklogSegment creates a backlogSegment object with an empty slice of
+// messages. It does not return an interface as it is only used inside the
+// backlog library.
+func newBacklogSegment() *backlogSegment {
+	return &backlogSegment{
+		messages: []*m.BroadcastFeedMessage{},
+	}
+}
+
+// IsBacklogSegmentNil uses the internal backlogSegment type to check if a
+// variable of type BacklogSegment is nil or not. Comparing whether an
+// interface is nil directly will not work.
+func IsBacklogSegmentNil(segment BacklogSegment) bool {
+	if segment == nil {
+		return true
+	} else if segment.(*backlogSegment) == nil {
+		return true
+	}
+	return false
+}
+
+// Start returns the first message index within the backlogSegment.
+func (s *backlogSegment) Start() uint64 {
+	s.messagesLock.RLock()
+	defer s.messagesLock.RUnlock()
+	return s.start()
+}
+
+// start allows the first message to be retrieved from functions that already
+// have the messagesLock.
+func (s *backlogSegment) start() uint64 {
+	if len(s.messages) > 0 {
+		return uint64(s.messages[0].SequenceNumber)
+	}
+	return uint64(0)
+}
+
+// End returns the last message index within the backlogSegment.
+func (s *backlogSegment) End() uint64 {
+	s.messagesLock.RLock()
+	defer s.messagesLock.RUnlock()
+	return s.end()
+}
+
+// end allows the last message to be retrieved from functions that already
+// have the messagesLock.
+func (s *backlogSegment) end() uint64 {
+	c := len(s.messages)
+	if c == 0 {
+		return uint64(0)
+	}
+	return uint64(s.messages[c-1].SequenceNumber)
+}
+
+// Next returns the next backlogSegment.
+func (s *backlogSegment) Next() BacklogSegment {
+	next := s.nextSegment.Load()
+	if next == nil {
+		return nil // return a nil interface instead of a nil *backlogSegment
+	}
+	return next
+}
+
+// Messages returns all of the messages stored in the backlogSegment.
+func (s *backlogSegment) Messages() []*m.BroadcastFeedMessage {
+	s.messagesLock.RLock()
+	defer s.messagesLock.RUnlock()
+	tmp := make([]*m.BroadcastFeedMessage, len(s.messages))
+	copy(tmp, s.messages)
+	return tmp
+}
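`Next()` deliberately converts a nil `*backlogSegment` into a nil interface, and `IsBacklogSegmentNil` exists because that conversion is easy to get wrong elsewhere. A standalone sketch (independent of this PR) of the typed-nil pitfall it guards against:

```go
package main

import "fmt"

type Segment interface{ Start() uint64 }

type segment struct{}

func (s *segment) Start() uint64 { return 0 }

func main() {
	var s *segment // typed nil pointer
	var i Segment = s

	// The interface now holds a (*segment)(nil) value, so comparing it
	// directly against nil reports false even though the pointer is nil.
	fmt.Println(i == nil)            // false
	fmt.Println(i.(*segment) == nil) // true: the check IsBacklogSegmentNil performs
}
```

+
+// Get reads messages from the given start to end message index.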
+func (s *backlogSegment) Get(start, end uint64) ([]*m.BroadcastFeedMessage, error) { + s.messagesLock.RLock() + defer s.messagesLock.RUnlock() + noMsgs := []*m.BroadcastFeedMessage{} + if start < s.start() { + return noMsgs, errOutOfBounds + } + + if end > s.end() { + return noMsgs, errOutOfBounds + } + + startIndex := start - s.start() + endIndex := end - s.start() + 1 + + tmp := make([]*m.BroadcastFeedMessage, endIndex-startIndex) + copy(tmp, s.messages[startIndex:endIndex]) + return tmp, nil +} + +// append appends the given BroadcastFeedMessage to messages if it is the first +// message in the sequence or the next in the sequence. If segment's end +// message is ahead of the given message append will do nothing. If the given +// message is ahead of the segment's end message append will return +// errDropSegments to ensure any messages before the given message are dropped. +func (s *backlogSegment) append(prevMsgIdx uint64, msg *m.BroadcastFeedMessage) error { + s.messagesLock.Lock() + defer s.messagesLock.Unlock() + + if expSeqNum := prevMsgIdx + 1; prevMsgIdx == 0 || uint64(msg.SequenceNumber) == expSeqNum { + s.messages = append(s.messages, msg) + } else if uint64(msg.SequenceNumber) > expSeqNum { + s.messages = nil + s.messages = append(s.messages, msg) + return fmt.Errorf("new message sequence number (%d) is greater than the expected sequence number (%d): %w", msg.SequenceNumber, expSeqNum, errDropSegments) + } else { + return errSequenceNumberSeen + } + return nil +} + +// Contains confirms whether the segment contains a message with the given +// sequence number. +func (s *backlogSegment) Contains(i uint64) bool { + s.messagesLock.RLock() + defer s.messagesLock.RUnlock() + start := s.start() + if i < start || i > s.end() { + return false + } + + msgIndex := i - start + msg := s.messages[msgIndex] + return uint64(msg.SequenceNumber) == i +} + +// delete removes messages from the backlogSegment up to and including the +// given confirmed message index. +func (s *backlogSegment) delete(confirmed uint64) error { + start := s.Start() + end := s.End() + msgIndex := confirmed - start + if !s.Contains(confirmed) { + return fmt.Errorf("confirmed message (%d) is not in expected index (%d) in current backlog (%d-%d)", confirmed, msgIndex, start, end) + } + + s.messagesLock.Lock() + s.messages = s.messages[msgIndex+1:] + s.messagesLock.Unlock() + return nil +} + +// count returns the number of messages stored in the backlog segment. 
+func (s *backlogSegment) count() int { + s.messagesLock.RLock() + defer s.messagesLock.RUnlock() + return len(s.messages) +} diff --git a/broadcaster/backlog/backlog_test.go b/broadcaster/backlog/backlog_test.go new file mode 100644 index 0000000000..ab25a523f7 --- /dev/null +++ b/broadcaster/backlog/backlog_test.go @@ -0,0 +1,461 @@ +package backlog + +import ( + "errors" + "sync" + "testing" + "time" + + "github.com/offchainlabs/nitro/arbutil" + m "github.com/offchainlabs/nitro/broadcaster/message" + "github.com/offchainlabs/nitro/util/arbmath" + "github.com/offchainlabs/nitro/util/containers" +) + +func validateBacklog(t *testing.T, b *backlog, count, start, end uint64, lookupKeys []arbutil.MessageIndex) { + if b.Count() != count { + t.Errorf("backlog message count (%d) does not equal expected message count (%d)", b.Count(), count) + } + + head := b.head.Load() + if start != 0 && head.Start() != start { + t.Errorf("head of backlog (%d) does not equal expected head (%d)", head.Start(), start) + } + + tail := b.tail.Load() + if end != 0 && tail.End() != end { + t.Errorf("tail of backlog (%d) does not equal expected tail (%d)", tail.End(), end) + } + + for _, k := range lookupKeys { + if _, err := b.Lookup(uint64(k)); err != nil { + t.Errorf("failed to find message (%d) in lookup", k) + } + } + + expLen := len(lookupKeys) + actualLen := int(b.Count()) + if expLen != actualLen { + t.Errorf("expected length of lookupByIndex map (%d) does not equal actual length (%d)", expLen, actualLen) + } +} + +func validateBroadcastMessage(t *testing.T, bm *m.BroadcastMessage, expectedCount int, start, end uint64) { + actualCount := len(bm.Messages) + if actualCount != expectedCount { + t.Errorf("number of messages returned (%d) does not equal the expected number of messages (%d)", actualCount, expectedCount) + } + + s := arbmath.MaxInt(start, 40) + for i := s; i <= end; i++ { + msg := bm.Messages[i-s] + if uint64(msg.SequenceNumber) != i { + t.Errorf("unexpected sequence number (%d) in %d returned message", i, i-s) + } + } +} + +func createDummyBacklog(indexes []arbutil.MessageIndex) (*backlog, error) { + b := &backlog{ + lookupByIndex: &containers.SyncMap[uint64, *backlogSegment]{}, + config: func() *Config { return &DefaultTestConfig }, + } + bm := &m.BroadcastMessage{Messages: m.CreateDummyBroadcastMessages(indexes)} + err := b.Append(bm) + return b, err +} + +func TestAppend(t *testing.T) { + testcases := []struct { + name string + backlogIndexes []arbutil.MessageIndex + newIndexes []arbutil.MessageIndex + expectedCount uint64 + expectedStart uint64 + expectedEnd uint64 + expectedLookupKeys []arbutil.MessageIndex + }{ + { + "EmptyBacklog", + []arbutil.MessageIndex{}, + []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46}, + 7, + 40, + 46, + []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46}, + }, + { + "NonEmptyBacklog", + []arbutil.MessageIndex{40, 41}, + []arbutil.MessageIndex{42, 43, 44, 45, 46}, + 7, + 40, + 46, + []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46}, + }, + { + "NonSequential", + []arbutil.MessageIndex{40, 41}, + []arbutil.MessageIndex{42, 43, 45, 46}, + 2, // Message 45 is non sequential, the previous messages will be dropped from the backlog + 45, + 46, + []arbutil.MessageIndex{45, 46}, + }, + { + "MessageSeen", + []arbutil.MessageIndex{40, 41}, + []arbutil.MessageIndex{42, 43, 44, 45, 46, 41}, + 7, // Message 41 is already present in the backlog, it will be ignored + 40, + 46, + []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46}, + }, + { + "NonSequentialFirstSegmentMessage", + 
[]arbutil.MessageIndex{40, 41},
+			[]arbutil.MessageIndex{42, 44, 45, 46},
+			3, // Message 44 is non sequential and the first message in a new segment, the previous messages will be dropped from the backlog
+			44,
+			46,
+			[]arbutil.MessageIndex{44, 45, 46},
+		},
+		{
+			"MessageSeenFirstSegmentMessage",
+			[]arbutil.MessageIndex{40, 41},
+			[]arbutil.MessageIndex{42, 43, 44, 45, 41, 46},
+			7, // Message 41 is already present in the backlog and the first message in a new segment, it will be ignored
+			40,
+			46,
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			// The segment limit is 3, the above test cases have been created
+			// to include testing certain actions on the first message of a
+			// new segment.
+			b, err := createDummyBacklog(tc.backlogIndexes)
+			if err != nil {
+				t.Fatalf("error creating dummy backlog: %s", err)
+			}
+
+			bm := &m.BroadcastMessage{Messages: m.CreateDummyBroadcastMessages(tc.newIndexes)}
+			err = b.Append(bm)
+			if err != nil {
+				t.Fatalf("error appending BroadcastMessage: %s", err)
+			}
+
+			validateBacklog(t, b, tc.expectedCount, tc.expectedStart, tc.expectedEnd, tc.expectedLookupKeys)
+		})
+	}
+}
+
+func TestDeleteInvalidBacklog(t *testing.T) {
+	// Create a backlog with an invalid sequence
+	s := &backlogSegment{
+		messages: m.CreateDummyBroadcastMessages([]arbutil.MessageIndex{40, 42}),
+	}
+
+	lookup := &containers.SyncMap[uint64, *backlogSegment]{}
+	lookup.Store(40, s)
+	b := &backlog{
+		lookupByIndex: lookup,
+		config:        func() *Config { return &DefaultTestConfig },
+	}
+	b.messageCount.Store(2)
+	b.head.Store(s)
+	b.tail.Store(s)
+
+	bm := &m.BroadcastMessage{
+		Messages: nil,
+		ConfirmedSequenceNumberMessage: &m.ConfirmedSequenceNumberMessage{
+			SequenceNumber: 41,
+		},
+	}
+
+	err := b.Append(bm)
+	if err != nil {
+		t.Fatalf("error appending BroadcastMessage: %s", err)
+	}
+
+	validateBacklog(t, b, 0, 0, 0, []arbutil.MessageIndex{})
+}
+
+func TestDelete(t *testing.T) {
+	testcases := []struct {
+		name               string
+		backlogIndexes     []arbutil.MessageIndex
+		confirmed          arbutil.MessageIndex
+		expectedCount      uint64
+		expectedStart      uint64
+		expectedEnd        uint64
+		expectedLookupKeys []arbutil.MessageIndex
+	}{
+		{
+			"EmptyBacklog",
+			[]arbutil.MessageIndex{},
+			0, // no segments in backlog so nothing to delete
+			0,
+			0,
+			0,
+			[]arbutil.MessageIndex{},
+		},
+		{
+			"MsgBeforeBacklog",
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+			39, // no segments will be deleted
+			7,
+			40,
+			46,
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+		},
+		{
+			"FirstMsgInBacklog",
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+			40, // this is the first message in the backlog
+			6,
+			41,
+			46,
+			[]arbutil.MessageIndex{41, 42, 43, 44, 45, 46},
+		},
+		{
+			"FirstMsgInSegment",
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+			43, // this is the first message in a middle segment of the backlog
+			3,
+			44,
+			46,
+			[]arbutil.MessageIndex{44, 45, 46},
+		},
+		{
+			"MiddleMsgInSegment",
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+			44, // this is a message in the middle of a middle segment of the backlog
+			2,
+			45,
+			46,
+			[]arbutil.MessageIndex{45, 46},
+		},
+		{
+			"LastMsgInSegment",
+			[]arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46},
+			45, // this is the last message in a middle segment of the backlog, the whole segment should be deleted along with any segments before it
+			1,
+			46,
+			46,
+			[]arbutil.MessageIndex{46},
+		},
+		{
+			"MsgAfterBacklog",
+			[]arbutil.MessageIndex{40, 41, 42, 43,
44, 45, 46}, + 47, // all segments will be deleted + 0, + 0, + 0, + []arbutil.MessageIndex{}, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + b, err := createDummyBacklog(tc.backlogIndexes) + if err != nil { + t.Fatalf("error creating dummy backlog: %s", err) + } + + bm := &m.BroadcastMessage{ + Messages: nil, + ConfirmedSequenceNumberMessage: &m.ConfirmedSequenceNumberMessage{ + SequenceNumber: tc.confirmed, + }, + } + err = b.Append(bm) + if err != nil { + t.Fatalf("error appending BroadcastMessage: %s", err) + } + + validateBacklog(t, b, tc.expectedCount, tc.expectedStart, tc.expectedEnd, tc.expectedLookupKeys) + }) + } +} + +// make sure that an append, then delete, then append ends up with the correct messageCounts + +func TestGetEmptyBacklog(t *testing.T) { + b, err := createDummyBacklog([]arbutil.MessageIndex{}) + if err != nil { + t.Fatalf("error creating dummy backlog: %s", err) + } + + _, err = b.Get(1, 2) + if !errors.Is(err, errOutOfBounds) { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestGet(t *testing.T) { + indexes := []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46} + b, err := createDummyBacklog(indexes) + if err != nil { + t.Fatalf("error creating dummy backlog: %s", err) + } + + testcases := []struct { + name string + start uint64 + end uint64 + expectedErr error + expectedCount int + }{ + { + "LowerBoundFar", + 0, + 43, + nil, + 4, + }, + { + "LowerBoundClose", + 39, + 43, + nil, + 4, + }, + { + "UpperBoundFar", + 43, + 18446744073709551615, + errOutOfBounds, + 0, + }, + { + "UpperBoundClose", + 0, + 47, + errOutOfBounds, + 0, + }, + { + "AllMessages", + 40, + 46, + nil, + 7, + }, + { + "SomeMessages", + 42, + 44, + nil, + 3, + }, + { + "FirstMessage", + 40, + 40, + nil, + 1, + }, + { + "LastMessage", + 46, + 46, + nil, + 1, + }, + { + "SingleMessage", + 43, + 43, + nil, + 1, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + bm, err := b.Get(tc.start, tc.end) + if !errors.Is(err, tc.expectedErr) { + t.Fatalf("unexpected error: %s", err) + } + + // Some of the tests are checking the correct error is returned + // Do not check bm if an error should be returned + if tc.expectedErr == nil { + validateBroadcastMessage(t, bm, tc.expectedCount, tc.start, tc.end) + } + }) + } +} + +// TestBacklogRaceCondition performs read & write operations in separate +// goroutines to ensure that the backlog does not have race conditions. The +// `go test -race` command can be used to test this. 
+func TestBacklogRaceCondition(t *testing.T) {
+	indexes := []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46}
+	b, err := createDummyBacklog(indexes)
+	if err != nil {
+		t.Fatalf("error creating dummy backlog: %s", err)
+	}
+
+	wg := sync.WaitGroup{}
+	newIndexes := []arbutil.MessageIndex{47, 48, 49, 50, 51, 52, 53, 54, 55}
+
+	// Write to backlog in goroutine
+	wg.Add(1)
+	errs := make(chan error, 15)
+	go func(t *testing.T, b *backlog) {
+		defer wg.Done()
+		for _, i := range newIndexes {
+			bm := m.CreateDummyBroadcastMessage([]arbutil.MessageIndex{i})
+			err := b.Append(bm)
+			errs <- err
+			if err != nil {
+				return
+			}
+			time.Sleep(time.Millisecond)
+		}
+	}(t, b)
+
+	// Read from backlog in goroutine
+	wg.Add(1)
+	go func(t *testing.T, b *backlog) {
+		defer wg.Done()
+		for _, i := range []uint64{42, 43, 44, 45, 46, 47} {
+			bm, err := b.Get(i, i+1)
+			errs <- err
+			if err != nil {
+				return
+			} else {
+				validateBroadcastMessage(t, bm, 2, i, i+1)
+			}
+			time.Sleep(2 * time.Millisecond)
+		}
+	}(t, b)
+
+	// Delete from backlog in goroutine. This is normally done via Append with
+	// a confirmed sequence number, using delete method for simplicity in test.
+	wg.Add(1)
+	go func(t *testing.T, b *backlog) {
+		defer wg.Done()
+		for _, i := range []uint64{40, 43, 47} {
+			b.delete(i)
+			time.Sleep(10 * time.Millisecond)
+		}
+	}(t, b)
+
+	// Wait for all goroutines to finish or return errors
+	wg.Wait()
+	close(errs)
+	for err = range errs {
+		if err != nil {
+			t.Fatalf("unexpected error: %s", err)
+		}
+	}
+	// Messages up to 47 were deleted. However, the segment that 47 was in is
+	// kept, which is why the backlog starts at 48.
+	validateBacklog(t, b, 8, 48, 55, newIndexes[1:])
+}
diff --git a/broadcaster/backlog/config.go b/broadcaster/backlog/config.go
new file mode 100644
index 0000000000..0e760cd0cf
--- /dev/null
+++ b/broadcaster/backlog/config.go
@@ -0,0 +1,24 @@
+package backlog
+
+import (
+	flag "github.com/spf13/pflag"
+)
+
+type ConfigFetcher func() *Config
+
+type Config struct {
+	SegmentLimit int `koanf:"segment-limit" reload:"hot"`
+}
+
+func AddOptions(prefix string, f *flag.FlagSet) {
+	f.Int(prefix+".segment-limit", DefaultConfig.SegmentLimit, "the maximum number of messages each segment within the backlog can contain")
+}
+
+var (
+	DefaultConfig = Config{
+		SegmentLimit: 240,
+	}
+	DefaultTestConfig = Config{
+		SegmentLimit: 3,
+	}
+)
diff --git a/broadcaster/broadcaster.go b/broadcaster/broadcaster.go
index c3f4c62ce0..8a70e39810 100644
--- a/broadcaster/broadcaster.go
+++ b/broadcaster/broadcaster.go
@@ -9,68 +9,34 @@ import (
 
 	"github.com/gobwas/ws"
 
-	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/log"
 
 	"github.com/offchainlabs/nitro/arbos/arbostypes"
 	"github.com/offchainlabs/nitro/arbutil"
+	"github.com/offchainlabs/nitro/broadcaster/backlog"
+	m "github.com/offchainlabs/nitro/broadcaster/message"
 	"github.com/offchainlabs/nitro/util/signature"
 	"github.com/offchainlabs/nitro/wsbroadcastserver"
 )
 
 type Broadcaster struct {
-	server        *wsbroadcastserver.WSBroadcastServer
-	catchupBuffer *SequenceNumberCatchupBuffer
-	chainId       uint64
-	dataSigner    signature.DataSignerFunc
-}
-
-// BroadcastMessage is the base message type for messages to send over the network.
-//
-// Acts as a variant holding the message types. The type of the message is
-// indicated by whichever of the fields is non-empty. The fields holding the message
-// types are annotated with omitempty so only the populated message is sent as
-// json. The message fields should be pointers or slices and end with
-// "Messages" or "Message".
-//
-// The format is forwards compatible, ie if a json BroadcastMessage is received that
-// has fields that are not in the Go struct then deserialization will succeed
-// skip the unknown field [1]
-//
-// References:
-// [1] https://pkg.go.dev/encoding/json#Unmarshal
-type BroadcastMessage struct {
-	Version int `json:"version"`
-	// TODO better name than messages since there are different types of messages
-	Messages                       []*BroadcastFeedMessage         `json:"messages,omitempty"`
-	ConfirmedSequenceNumberMessage *ConfirmedSequenceNumberMessage `json:"confirmedSequenceNumberMessage,omitempty"`
-}
-
-type BroadcastFeedMessage struct {
-	SequenceNumber arbutil.MessageIndex           `json:"sequenceNumber"`
-	Message        arbostypes.MessageWithMetadata `json:"message"`
-	Signature      []byte                         `json:"signature"`
-}
-
-func (m *BroadcastFeedMessage) Hash(chainId uint64) (common.Hash, error) {
-	return m.Message.Hash(m.SequenceNumber, chainId)
-}
-
-type ConfirmedSequenceNumberMessage struct {
-	SequenceNumber arbutil.MessageIndex `json:"sequenceNumber"`
+	server     *wsbroadcastserver.WSBroadcastServer
+	backlog    backlog.Backlog
+	chainId    uint64
+	dataSigner signature.DataSignerFunc
 }
 
 func NewBroadcaster(config wsbroadcastserver.BroadcasterConfigFetcher, chainId uint64, feedErrChan chan error, dataSigner signature.DataSignerFunc) *Broadcaster {
-	catchupBuffer := NewSequenceNumberCatchupBuffer(func() bool { return config().LimitCatchup }, func() int { return config().MaxCatchup })
+	bklg := backlog.NewBacklog(func() *backlog.Config { return &config().Backlog })
 	return &Broadcaster{
-		server:        wsbroadcastserver.NewWSBroadcastServer(config, catchupBuffer, chainId, feedErrChan),
-		catchupBuffer: catchupBuffer,
-		chainId:       chainId,
-		dataSigner:    dataSigner,
+		server:     wsbroadcastserver.NewWSBroadcastServer(config, bklg, chainId, feedErrChan),
+		backlog:    bklg,
+		chainId:    chainId,
+		dataSigner: dataSigner,
 	}
 }
 
-func (b *Broadcaster) NewBroadcastFeedMessage(message arbostypes.MessageWithMetadata, sequenceNumber arbutil.MessageIndex) (*BroadcastFeedMessage, error) {
+func (b *Broadcaster) NewBroadcastFeedMessage(message arbostypes.MessageWithMetadata, sequenceNumber arbutil.MessageIndex) (*m.BroadcastFeedMessage, error) {
 	var messageSignature []byte
 	if b.dataSigner != nil {
 		hash, err := message.Hash(sequenceNumber, b.chainId)
@@ -83,7 +49,7 @@ func (b *Broadcaster) NewBroadcastFeedMessage(message arbostypes.MessageWithMeta
 		}
 	}
 
-	return &BroadcastFeedMessage{
+	return &m.BroadcastFeedMessage{
 		SequenceNumber: sequenceNumber,
 		Message:        message,
 		Signature:      messageSignature,
@@ -105,17 +71,17 @@ func (b *Broadcaster) BroadcastSingle(msg arbostypes.MessageWithMetadata, seq ar
 	return nil
 }
 
-func (b *Broadcaster) BroadcastSingleFeedMessage(bfm *BroadcastFeedMessage) {
-	broadcastFeedMessages := make([]*BroadcastFeedMessage, 0, 1)
+func (b *Broadcaster) BroadcastSingleFeedMessage(bfm *m.BroadcastFeedMessage) {
+	broadcastFeedMessages := make([]*m.BroadcastFeedMessage, 0, 1)
 
 	broadcastFeedMessages = append(broadcastFeedMessages, bfm)
 
 	b.BroadcastFeedMessages(broadcastFeedMessages)
 }
 
-func (b *Broadcaster) BroadcastFeedMessages(messages []*BroadcastFeedMessage) {
+func (b *Broadcaster) BroadcastFeedMessages(messages []*m.BroadcastFeedMessage) {
 
-	bm := BroadcastMessage{
+	bm := &m.BroadcastMessage{
 		Version:  1,
 		Messages: messages,
 	}
@@ -125,9 +91,12 @@ func (b *Broadcaster) BroadcastFeedMessages(messages []*BroadcastFeedMessage) {
 
 func (b *Broadcaster) Confirm(seq arbutil.MessageIndex) {
 	log.Debug("confirming sequence number", "sequenceNumber", seq)
-	b.server.Broadcast(BroadcastMessage{
-		Version:                        1,
-		ConfirmedSequenceNumberMessage: &ConfirmedSequenceNumberMessage{seq}})
+	b.server.Broadcast(&m.BroadcastMessage{
+		Version: 1,
+		ConfirmedSequenceNumberMessage: &m.ConfirmedSequenceNumberMessage{
+			SequenceNumber: seq,
+		},
+	})
 }
 
 func (b *Broadcaster) ClientCount() int32 {
@@ -139,7 +108,7 @@ func (b *Broadcaster) ListenerAddr() net.Addr {
 }
 
 func (b *Broadcaster) GetCachedMessageCount() int {
-	return b.catchupBuffer.GetMessageCount()
+	return int(b.backlog.Count())
 }
 
 func (b *Broadcaster) Initialize() error {
diff --git a/broadcaster/broadcaster_test.go b/broadcaster/broadcaster_test.go
index a97facef30..8ac06e9705 100644
--- a/broadcaster/broadcaster_test.go
+++ b/broadcaster/broadcaster_test.go
@@ -44,7 +44,7 @@ type messageCountPredicate struct {
 }
 
 func (p *messageCountPredicate) Test() bool {
-	p.was = p.b.catchupBuffer.GetMessageCount()
+	p.was = p.b.GetCachedMessageCount()
 	return p.was == p.expected
 }
 
@@ -78,26 +78,30 @@ func TestBroadcasterMessagesRemovedOnConfirmation(t *testing.T) {
 	waitUntilUpdated(t, expectMessageCount(3, "after 3 messages"))
 	Require(t, b.BroadcastSingle(arbostypes.EmptyTestMessageWithMetadata, 4))
 	waitUntilUpdated(t, expectMessageCount(4, "after 4 messages"))
+	Require(t, b.BroadcastSingle(arbostypes.EmptyTestMessageWithMetadata, 5))
+	waitUntilUpdated(t, expectMessageCount(5, "after 5 messages"))
+	Require(t, b.BroadcastSingle(arbostypes.EmptyTestMessageWithMetadata, 6))
+	waitUntilUpdated(t, expectMessageCount(6, "after 6 messages"))
 
-	b.Confirm(1)
-	waitUntilUpdated(t, expectMessageCount(3,
-		"after 4 messages, 1 cleared by confirm"))
-
-	b.Confirm(2)
+	b.Confirm(4)
 	waitUntilUpdated(t, expectMessageCount(2,
-		"after 4 messages, 2 cleared by confirm"))
+		"after 6 messages, 4 cleared by confirm"))
 
-	b.Confirm(1)
-	waitUntilUpdated(t, expectMessageCount(2,
+	b.Confirm(5)
+	waitUntilUpdated(t, expectMessageCount(1,
+		"after 6 messages, 5 cleared by confirm"))
+
+	b.Confirm(4)
+	waitUntilUpdated(t, expectMessageCount(1,
 		"nothing changed because confirmed sequence number before cache"))
 
-	b.Confirm(2)
-	Require(t, b.BroadcastSingle(arbostypes.EmptyTestMessageWithMetadata, 5))
-	waitUntilUpdated(t, expectMessageCount(3,
-		"after 5 messages, 2 cleared by confirm"))
+	b.Confirm(5)
+	Require(t, b.BroadcastSingle(arbostypes.EmptyTestMessageWithMetadata, 7))
+	waitUntilUpdated(t, expectMessageCount(2,
+		"after 7 messages, 5 cleared by confirm"))
 
 	// Confirm not-yet-seen or already confirmed/cleared sequence numbers twice to force clearing cache
-	b.Confirm(6)
+	b.Confirm(8)
 	waitUntilUpdated(t, expectMessageCount(0,
 		"clear all messages after confirmed 1 beyond latest"))
 }
diff --git a/broadcaster/message/message.go b/broadcaster/message/message.go
new file mode 100644
index 0000000000..f436e765cb
--- /dev/null
+++ b/broadcaster/message/message.go
@@ -0,0 +1,46 @@
+package message
+
+import (
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/offchainlabs/nitro/arbos/arbostypes"
+	"github.com/offchainlabs/nitro/arbutil"
+)
+
+const (
+	V1 = 1
+)
+
+// BroadcastMessage is the base message type for messages to send over the network.
+//
+// Acts as a variant holding the message types. The type of the message is
+// indicated by whichever of the fields is non-empty. The fields holding the message
+// types are annotated with omitempty so only the populated message is sent as
+// json. The message fields should be pointers or slices and end with
+// "Messages" or "Message".
+//
+// The format is forwards compatible, i.e., if a json BroadcastMessage is received that
+// has fields that are not in the Go struct then deserialization will succeed and
+// skip the unknown field [1]
+//
+// References:
+// [1] https://pkg.go.dev/encoding/json#Unmarshal
+type BroadcastMessage struct {
+	Version int `json:"version"`
+	// TODO better name than messages since there are different types of messages
+	Messages                       []*BroadcastFeedMessage         `json:"messages,omitempty"`
+	ConfirmedSequenceNumberMessage *ConfirmedSequenceNumberMessage `json:"confirmedSequenceNumberMessage,omitempty"`
+}
+
+type BroadcastFeedMessage struct {
+	SequenceNumber arbutil.MessageIndex           `json:"sequenceNumber"`
+	Message        arbostypes.MessageWithMetadata `json:"message"`
+	Signature      []byte                         `json:"signature"`
+}
+
+func (m *BroadcastFeedMessage) Hash(chainId uint64) (common.Hash, error) {
+	return m.Message.Hash(m.SequenceNumber, chainId)
+}
+
+type ConfirmedSequenceNumberMessage struct {
+	SequenceNumber arbutil.MessageIndex `json:"sequenceNumber"`
+}
diff --git a/broadcaster/broadcaster_serialization_test.go b/broadcaster/message/message_serialization_test.go
similarity index 98%
rename from broadcaster/broadcaster_serialization_test.go
rename to broadcaster/message/message_serialization_test.go
index 64adb49126..c3e14a86ae 100644
--- a/broadcaster/broadcaster_serialization_test.go
+++ b/broadcaster/message/message_serialization_test.go
@@ -1,7 +1,7 @@
 // Copyright 2021-2022, Offchain Labs, Inc.
 // For license information, see https://github.com/nitro/blob/master/LICENSE
 
-package broadcaster
+package message
 
 import (
 	"bytes"
diff --git a/broadcaster/message/message_test_utils.go b/broadcaster/message/message_test_utils.go
new file mode 100644
index 0000000000..0943b49c60
--- /dev/null
+++ b/broadcaster/message/message_test_utils.go
@@ -0,0 +1,29 @@
+package message
+
+import (
+	"github.com/offchainlabs/nitro/arbos/arbostypes"
+	"github.com/offchainlabs/nitro/arbutil"
+)
+
+func CreateDummyBroadcastMessage(seqNums []arbutil.MessageIndex) *BroadcastMessage {
+	return &BroadcastMessage{
+		Messages: CreateDummyBroadcastMessages(seqNums),
+	}
+}
+
+func CreateDummyBroadcastMessages(seqNums []arbutil.MessageIndex) []*BroadcastFeedMessage {
+	return CreateDummyBroadcastMessagesImpl(seqNums, len(seqNums))
+}
+
+func CreateDummyBroadcastMessagesImpl(seqNums []arbutil.MessageIndex, length int) []*BroadcastFeedMessage {
+	broadcastMessages := make([]*BroadcastFeedMessage, 0, length)
+	for _, seqNum := range seqNums {
+		broadcastMessage := &BroadcastFeedMessage{
+			SequenceNumber: seqNum,
+			Message:        arbostypes.EmptyTestMessageWithMetadata,
+		}
+		broadcastMessages = append(broadcastMessages, broadcastMessage)
+	}
+
+	return broadcastMessages
+}
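Since BroadcastMessage acts as a JSON variant type, it may help to see what actually goes over the wire. A small sketch (not part of the diff) marshalling a confirmation-only message; the omitempty tags drop the unused variant fields:

```go
package main

import (
	"encoding/json"
	"fmt"

	m "github.com/offchainlabs/nitro/broadcaster/message"
)

func main() {
	// A confirmation-only message: Messages is nil, so omitempty drops it
	// and the populated variant field identifies the message type.
	bm := m.BroadcastMessage{
		Version: m.V1,
		ConfirmedSequenceNumberMessage: &m.ConfirmedSequenceNumberMessage{
			SequenceNumber: 42,
		},
	}
	out, err := json.Marshal(bm)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
	// {"version":1,"confirmedSequenceNumberMessage":{"sequenceNumber":42}}
}
```

diff --git a/broadcaster/sequencenumbercatchupbuffer.go b/broadcaster/sequencenumbercatchupbuffer.go
deleted file mode 100644
index bdd3e60c5b..0000000000
--- a/broadcaster/sequencenumbercatchupbuffer.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2021-2022, Offchain Labs, Inc.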
-// For license information, see https://github.com/nitro/blob/master/LICENSE - -package broadcaster - -import ( - "errors" - "sync/atomic" - "time" - - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/metrics" - - "github.com/offchainlabs/nitro/arbutil" - "github.com/offchainlabs/nitro/wsbroadcastserver" -) - -const ( - // Do not send cache if requested seqnum is older than last cached minus maxRequestedSeqNumOffset - maxRequestedSeqNumOffset = arbutil.MessageIndex(10_000) -) - -var ( - confirmedSequenceNumberGauge = metrics.NewRegisteredGauge("arb/sequencenumber/confirmed", nil) - cachedMessagesSentHistogram = metrics.NewRegisteredHistogram("arb/feed/clients/cache/sent", nil, metrics.NewBoundedHistogramSample()) -) - -type SequenceNumberCatchupBuffer struct { - messages []*BroadcastFeedMessage - messageCount int32 - limitCatchup func() bool - maxCatchup func() int -} - -func NewSequenceNumberCatchupBuffer(limitCatchup func() bool, maxCatchup func() int) *SequenceNumberCatchupBuffer { - return &SequenceNumberCatchupBuffer{ - limitCatchup: limitCatchup, - maxCatchup: maxCatchup, - } -} - -func (b *SequenceNumberCatchupBuffer) getCacheMessages(requestedSeqNum arbutil.MessageIndex) *BroadcastMessage { - if len(b.messages) == 0 { - return nil - } - var startingIndex int32 - // Ignore messages older than requested sequence number - firstCachedSeqNum := b.messages[0].SequenceNumber - if firstCachedSeqNum < requestedSeqNum { - lastCachedSeqNum := firstCachedSeqNum + arbutil.MessageIndex(len(b.messages)-1) - if lastCachedSeqNum < requestedSeqNum { - // Past end, nothing to return - return nil - } - startingIndex = int32(requestedSeqNum - firstCachedSeqNum) - if startingIndex >= int32(len(b.messages)) { - log.Error("unexpected startingIndex", "requestedSeqNum", requestedSeqNum, "firstCachedSeqNum", firstCachedSeqNum, "startingIndex", startingIndex, "lastCachedSeqNum", lastCachedSeqNum, "cacheLength", len(b.messages)) - return nil - } - if b.messages[startingIndex].SequenceNumber != requestedSeqNum { - log.Error("requestedSeqNum not found where expected", "requestedSeqNum", requestedSeqNum, "firstCachedSeqNum", firstCachedSeqNum, "startingIndex", startingIndex, "foundSeqNum", b.messages[startingIndex].SequenceNumber) - return nil - } - } else if b.limitCatchup() && firstCachedSeqNum > maxRequestedSeqNumOffset && requestedSeqNum < (firstCachedSeqNum-maxRequestedSeqNumOffset) { - // Requested seqnum is too old, don't send any cache - return nil - } - - messagesToSend := b.messages[startingIndex:] - if len(messagesToSend) > 0 { - bm := BroadcastMessage{ - Version: 1, - Messages: messagesToSend, - } - - return &bm - } - - return nil -} - -func (b *SequenceNumberCatchupBuffer) OnRegisterClient(clientConnection *wsbroadcastserver.ClientConnection) (error, int, time.Duration) { - start := time.Now() - bm := b.getCacheMessages(clientConnection.RequestedSeqNum()) - var bmCount int - if bm != nil { - bmCount = len(bm.Messages) - } - if bm != nil { - // send the newly connected client the requested messages - err := clientConnection.Write(bm) - if err != nil { - log.Error("error sending client cached messages", "error", err, "client", clientConnection.Name, "elapsed", time.Since(start)) - return err, 0, 0 - } - } - - cachedMessagesSentHistogram.Update(int64(bmCount)) - - return nil, bmCount, time.Since(start) -} - -// Takes as input an index into the messages array, not a message index -func (b *SequenceNumberCatchupBuffer) pruneBufferToIndex(idx int) { - b.messages = b.messages[idx:] 
- if len(b.messages) > 10 && cap(b.messages) > len(b.messages)*10 { - // Too much spare capacity, copy to fresh slice to reset memory usage - b.messages = append([]*BroadcastFeedMessage(nil), b.messages[:len(b.messages)]...) - } -} - -func (b *SequenceNumberCatchupBuffer) deleteConfirmed(confirmedSequenceNumber arbutil.MessageIndex) { - if len(b.messages) == 0 { - return - } - - firstSequenceNumber := b.messages[0].SequenceNumber - - if confirmedSequenceNumber < firstSequenceNumber { - // Confirmed sequence number is older than cache, so nothing to do - return - } - - confirmedIndex := uint64(confirmedSequenceNumber - firstSequenceNumber) - - if confirmedIndex >= uint64(len(b.messages)) { - log.Error("ConfirmedSequenceNumber is past the end of stored messages", "confirmedSequenceNumber", confirmedSequenceNumber, "firstSequenceNumber", firstSequenceNumber, "cacheLength", len(b.messages)) - b.messages = nil - return - } - - if b.messages[confirmedIndex].SequenceNumber != confirmedSequenceNumber { - // Log instead of returning error here so that the message will be sent to downstream - // relays to also cause them to be cleared. - log.Error("Invariant violation: confirmedSequenceNumber is not where expected, clearing buffer", "confirmedSequenceNumber", confirmedSequenceNumber, "firstSequenceNumber", firstSequenceNumber, "cacheLength", len(b.messages), "foundSequenceNumber", b.messages[confirmedIndex].SequenceNumber) - b.messages = nil - return - } - - b.pruneBufferToIndex(int(confirmedIndex) + 1) -} - -func (b *SequenceNumberCatchupBuffer) OnDoBroadcast(bmi interface{}) error { - broadcastMessage, ok := bmi.(BroadcastMessage) - if !ok { - msg := "requested to broadcast message of unknown type" - log.Error(msg) - return errors.New(msg) - } - defer func() { atomic.StoreInt32(&b.messageCount, int32(len(b.messages))) }() - - if confirmMsg := broadcastMessage.ConfirmedSequenceNumberMessage; confirmMsg != nil { - b.deleteConfirmed(confirmMsg.SequenceNumber) - confirmedSequenceNumberGauge.Update(int64(confirmMsg.SequenceNumber)) - } - - maxCatchup := b.maxCatchup() - if maxCatchup == 0 { - b.messages = nil - return nil - } - - for _, newMsg := range broadcastMessage.Messages { - if len(b.messages) == 0 { - // Add to empty list - b.messages = append(b.messages, newMsg) - } else if expectedSequenceNumber := b.messages[len(b.messages)-1].SequenceNumber + 1; newMsg.SequenceNumber == expectedSequenceNumber { - // Next sequence number to add to end of list - b.messages = append(b.messages, newMsg) - } else if newMsg.SequenceNumber > expectedSequenceNumber { - log.Warn( - "Message requested to be broadcast has unexpected sequence number; discarding to seqNum from catchup buffer", - "seqNum", newMsg.SequenceNumber, - "expectedSeqNum", expectedSequenceNumber, - ) - b.messages = nil - b.messages = append(b.messages, newMsg) - } else { - log.Info("Skipping already seen message", "seqNum", newMsg.SequenceNumber) - } - } - - if maxCatchup >= 0 && len(b.messages) > maxCatchup { - b.pruneBufferToIndex(len(b.messages) - maxCatchup) - } - - return nil - -} - -func (b *SequenceNumberCatchupBuffer) GetMessageCount() int { - return int(atomic.LoadInt32(&b.messageCount)) -} diff --git a/broadcaster/sequencenumbercatchupbuffer_test.go b/broadcaster/sequencenumbercatchupbuffer_test.go deleted file mode 100644 index fc6655057e..0000000000 --- a/broadcaster/sequencenumbercatchupbuffer_test.go +++ /dev/null @@ -1,271 +0,0 @@ -/* - * Copyright 2020-2021, Offchain Labs, Inc. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package broadcaster - -import ( - "strings" - "testing" - - "github.com/offchainlabs/nitro/arbos/arbostypes" - "github.com/offchainlabs/nitro/arbutil" - "github.com/offchainlabs/nitro/util/arbmath" -) - -func TestGetEmptyCacheMessages(t *testing.T) { - buffer := SequenceNumberCatchupBuffer{ - messages: nil, - messageCount: 0, - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - // Get everything - bm := buffer.getCacheMessages(0) - if bm != nil { - t.Error("shouldn't have returned anything") - } -} - -func createDummyBroadcastMessages(seqNums []arbutil.MessageIndex) []*BroadcastFeedMessage { - return createDummyBroadcastMessagesImpl(seqNums, len(seqNums)) -} -func createDummyBroadcastMessagesImpl(seqNums []arbutil.MessageIndex, length int) []*BroadcastFeedMessage { - broadcastMessages := make([]*BroadcastFeedMessage, 0, length) - for _, seqNum := range seqNums { - broadcastMessage := &BroadcastFeedMessage{ - SequenceNumber: seqNum, - Message: arbostypes.EmptyTestMessageWithMetadata, - } - broadcastMessages = append(broadcastMessages, broadcastMessage) - } - - return broadcastMessages -} - -func TestGetCacheMessages(t *testing.T) { - indexes := []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46} - buffer := SequenceNumberCatchupBuffer{ - messages: createDummyBroadcastMessages(indexes), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - // Get everything - bm := buffer.getCacheMessages(0) - if len(bm.Messages) != 7 { - t.Error("didn't return all messages") - } - - // Get everything - bm = buffer.getCacheMessages(1) - if len(bm.Messages) != 7 { - t.Error("didn't return all messages") - } - - // Get everything - bm = buffer.getCacheMessages(40) - if len(bm.Messages) != 7 { - t.Error("didn't return all messages") - } - - // Get nothing - bm = buffer.getCacheMessages(100) - if bm != nil { - t.Error("should not have returned anything") - } - - // Test single - bm = buffer.getCacheMessages(46) - if bm == nil { - t.Fatal("nothing returned") - } - if len(bm.Messages) != 1 { - t.Errorf("expected 1 message, got %d messages", len(bm.Messages)) - } - if bm.Messages[0].SequenceNumber != 46 { - t.Errorf("expected sequence number 46, got %d", bm.Messages[0].SequenceNumber) - } - - // Test extremes - bm = buffer.getCacheMessages(arbutil.MessageIndex(^uint64(0))) - if bm != nil { - t.Fatal("should not have returned anything") - } -} - -func TestDeleteConfirmedNil(t *testing.T) { - buffer := SequenceNumberCatchupBuffer{ - messages: nil, - messageCount: 0, - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - buffer.deleteConfirmed(0) - if len(buffer.messages) != 0 { - t.Error("nothing should be present") - } -} - -func TestDeleteConfirmInvalidOrder(t *testing.T) { - indexes := []arbutil.MessageIndex{40, 42} - buffer := SequenceNumberCatchupBuffer{ - messages: 
createDummyBroadcastMessages(indexes), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - // Confirm before cache - buffer.deleteConfirmed(41) - if len(buffer.messages) != 0 { - t.Error("cache not in contiguous order should have caused everything to be deleted") - } -} - -func TestDeleteConfirmed(t *testing.T) { - indexes := []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46} - buffer := SequenceNumberCatchupBuffer{ - messages: createDummyBroadcastMessages(indexes), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - // Confirm older than cache - buffer.deleteConfirmed(39) - if len(buffer.messages) != 7 { - t.Error("nothing should have been deleted") - } - -} -func TestDeleteFreeMem(t *testing.T) { - indexes := []arbutil.MessageIndex{40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51} - buffer := SequenceNumberCatchupBuffer{ - messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - // Confirm older than cache - buffer.deleteConfirmed(40) - if cap(buffer.messages) > 20 { - t.Error("extra memory was not freed, cap: ", cap(buffer.messages)) - } - -} - -func TestBroadcastBadMessage(t *testing.T) { - buffer := SequenceNumberCatchupBuffer{ - messages: nil, - messageCount: 0, - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - var foo int - err := buffer.OnDoBroadcast(foo) - if err == nil { - t.Error("expected error") - } - if !strings.Contains(err.Error(), "unknown type") { - t.Error("unexpected type") - } -} - -func TestBroadcastPastSeqNum(t *testing.T) { - indexes := []arbutil.MessageIndex{40} - buffer := SequenceNumberCatchupBuffer{ - messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - bm := BroadcastMessage{ - Messages: []*BroadcastFeedMessage{ - { - SequenceNumber: 39, - }, - }, - } - err := buffer.OnDoBroadcast(bm) - if err != nil { - t.Error("expected error") - } - -} - -func TestBroadcastFutureSeqNum(t *testing.T) { - indexes := []arbutil.MessageIndex{40} - buffer := SequenceNumberCatchupBuffer{ - messages: createDummyBroadcastMessagesImpl(indexes, len(indexes)*10+1), - messageCount: int32(len(indexes)), - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return -1 }, - } - - bm := BroadcastMessage{ - Messages: []*BroadcastFeedMessage{ - { - SequenceNumber: 42, - }, - }, - } - err := buffer.OnDoBroadcast(bm) - if err != nil { - t.Error("expected error") - } - -} - -func TestMaxCatchupBufferSize(t *testing.T) { - limit := 5 - buffer := SequenceNumberCatchupBuffer{ - messages: nil, - messageCount: 0, - limitCatchup: func() bool { return false }, - maxCatchup: func() int { return limit }, - } - - firstMessage := 10 - for i := firstMessage; i <= 20; i += 2 { - bm := BroadcastMessage{ - Messages: []*BroadcastFeedMessage{ - { - SequenceNumber: arbutil.MessageIndex(i), - }, - { - SequenceNumber: arbutil.MessageIndex(i + 1), - }, - }, - } - err := buffer.OnDoBroadcast(bm) - Require(t, err) - haveMessages := buffer.getCacheMessages(0) - expectedCount := arbmath.MinInt(i+len(bm.Messages)-firstMessage, limit) - if len(haveMessages.Messages) != expectedCount { - t.Errorf("after broadcasting 
messages %v and %v, expected to have %v messages but got %v", i, i+1, expectedCount, len(haveMessages.Messages)) - } - expectedFirstMessage := arbutil.MessageIndex(arbmath.MaxInt(firstMessage, i+len(bm.Messages)-limit)) - if haveMessages.Messages[0].SequenceNumber != expectedFirstMessage { - t.Errorf("after broadcasting messages %v and %v, expected the first message to be %v but got %v", i, i+1, expectedFirstMessage, haveMessages.Messages[0].SequenceNumber) - } - } -} diff --git a/relay/relay.go b/relay/relay.go index 4288902865..8e29971384 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -16,6 +16,7 @@ import ( "github.com/offchainlabs/nitro/broadcastclient" "github.com/offchainlabs/nitro/broadcastclients" "github.com/offchainlabs/nitro/broadcaster" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/offchainlabs/nitro/cmd/genericconf" "github.com/offchainlabs/nitro/cmd/util/confighelpers" "github.com/offchainlabs/nitro/util/sharedmetrics" @@ -28,14 +29,14 @@ type Relay struct { broadcastClients *broadcastclients.BroadcastClients broadcaster *broadcaster.Broadcaster confirmedSequenceNumberChan chan arbutil.MessageIndex - messageChan chan broadcaster.BroadcastFeedMessage + messageChan chan m.BroadcastFeedMessage } type MessageQueue struct { - queue chan broadcaster.BroadcastFeedMessage + queue chan m.BroadcastFeedMessage } -func (q *MessageQueue) AddBroadcastMessages(feedMessages []*broadcaster.BroadcastFeedMessage) error { +func (q *MessageQueue) AddBroadcastMessages(feedMessages []*m.BroadcastFeedMessage) error { for _, feedMessage := range feedMessages { q.queue <- *feedMessage } @@ -45,7 +46,7 @@ func (q *MessageQueue) AddBroadcastMessages(feedMessages []*broadcaster.Broadcas func NewRelay(config *Config, feedErrChan chan error) (*Relay, error) { - q := MessageQueue{make(chan broadcaster.BroadcastFeedMessage, config.Queue)} + q := MessageQueue{make(chan m.BroadcastFeedMessage, config.Queue)} confirmedSequenceNumberListener := make(chan arbutil.MessageIndex, config.Queue) diff --git a/wsbroadcastserver/clientconnection.go b/wsbroadcastserver/clientconnection.go index bdbeccfd23..49cd2af7e6 100644 --- a/wsbroadcastserver/clientconnection.go +++ b/wsbroadcastserver/clientconnection.go @@ -5,6 +5,7 @@ package wsbroadcastserver import ( "context" + "errors" "fmt" "math/rand" "net" @@ -12,7 +13,10 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/log" "github.com/offchainlabs/nitro/arbutil" + "github.com/offchainlabs/nitro/broadcaster/backlog" + m "github.com/offchainlabs/nitro/broadcaster/message" "github.com/gobwas/ws" "github.com/gobwas/ws/wsflate" @@ -20,6 +24,18 @@ import ( "github.com/offchainlabs/nitro/util/stopwaiter" ) +var errContextDone = errors.New("context done") + +type message struct { + data []byte + sequenceNumber *arbutil.MessageIndex +} + +type ClientConnectionAction struct { + cc *ClientConnection + create bool +} + // ClientConnection represents client connection. 
 type ClientConnection struct {
 	stopwaiter.StopWaiter
@@ -31,11 +47,15 @@ type ClientConnection struct {
 	desc            *netpoll.Desc
 	Name            string
-	clientManager   *ClientManager
+	clientAction    chan ClientConnectionAction
 
 	requestedSeqNum arbutil.MessageIndex
+	LastSentSeqNum  atomic.Uint64
 
 	lastHeardUnix int64
-	out           chan []byte
+	out           chan message
+	backlog       backlog.Backlog
+	registered    chan bool
+	backlogSent   bool
 
 	compression bool
 	flateReader *wsflate.Reader
@@ -46,11 +66,13 @@ type ClientConnection struct {
 func NewClientConnection(
 	conn net.Conn,
 	desc *netpoll.Desc,
-	clientManager *ClientManager,
+	clientAction chan ClientConnectionAction,
 	requestedSeqNum arbutil.MessageIndex,
 	connectingIP net.IP,
 	compression bool,
+	maxSendQueue int,
 	delay time.Duration,
+	bklg backlog.Backlog,
 ) *ClientConnection {
 	return &ClientConnection{
 		conn:            conn,
@@ -58,13 +80,16 @@ func NewClientConnection(
 		desc:            desc,
 		creation:        time.Now(),
 		Name:            fmt.Sprintf("%s@%s-%d", connectingIP, conn.RemoteAddr(), rand.Intn(10)),
-		clientManager:   clientManager,
+		clientAction:    clientAction,
 		requestedSeqNum: requestedSeqNum,
 		lastHeardUnix:   time.Now().Unix(),
-		out:             make(chan []byte, clientManager.config().MaxSendQueue),
+		out:             make(chan message, maxSendQueue),
 		compression:     compression,
 		flateReader:     NewFlateReader(),
 		delay:           delay,
+		backlog:         bklg,
+		registered:      make(chan bool, 1),
+		backlogSent:     false,
 	}
 }
 
@@ -76,42 +101,158 @@ func (cc *ClientConnection) Compression() bool {
 	return cc.compression
 }
 
+// Register sends the ClientConnection to be registered with the ClientManager.
+func (cc *ClientConnection) Register() {
+	cc.clientAction <- ClientConnectionAction{
+		cc:     cc,
+		create: true,
+	}
+}
+
+// Remove sends the ClientConnection to be removed from the ClientManager.
+func (cc *ClientConnection) Remove() {
+	cc.clientAction <- ClientConnectionAction{
+		cc:     cc,
+		create: false,
+	}
+}
+
+func (cc *ClientConnection) writeBacklog(ctx context.Context, segment backlog.BacklogSegment) error {
+	var prevSegment backlog.BacklogSegment
+	for !backlog.IsBacklogSegmentNil(segment) {
+		// The next segment must be fetched before this segment's messages are
+		// retrieved; this ensures that no new segment is added in between the
+		// two calls.
+		prevSegment = segment
+		segment = segment.Next()
+
+		select {
+		case <-ctx.Done():
+			return errContextDone
+		default:
+		}
+
+		msgs := prevSegment.Messages()
+		if prevSegment.Contains(uint64(cc.requestedSeqNum)) {
+			requestedIdx := int(cc.requestedSeqNum) - int(prevSegment.Start())
+			msgs = msgs[requestedIdx:]
+		}
+		bm := &m.BroadcastMessage{
+			Version:  m.V1,
+			Messages: msgs,
+		}
+		err := cc.writeBroadcastMessage(bm)
+		if err != nil {
+			return err
+		}
+
+		// Do not use the prevSegment.End() method here; the last sequence
+		// number must be taken from the messages that were actually sent, in
+		// case more messages were added to the segment in the meantime.
+		end := uint64(msgs[len(msgs)-1].SequenceNumber)
+		cc.LastSentSeqNum.Store(end)
+		log.Debug("segment sent to client", "client", cc.Name, "sentCount", len(bm.Messages), "lastSentSeqNum", end)
+	}
+	return nil
+}
+
+func (cc *ClientConnection) writeBroadcastMessage(bm *m.BroadcastMessage) error {
+	notCompressed, compressed, err := serializeMessage(bm, !cc.compression, cc.compression)
+	if err != nil {
+		return err
+	}
+
+	var data []byte
+	if cc.compression {
+		data = compressed.Bytes()
+	} else {
+		data = notCompressed.Bytes()
+	}
+	err = cc.writeRaw(data)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
 func (cc *ClientConnection) Start(parentCtx context.Context) {
 	cc.StopWaiter.Start(parentCtx, cc)
 	cc.LaunchThread(func(ctx context.Context) {
+		// A delay may be configured; it ensures the Broadcaster waits before
+		// sending any messages to the client. The ClientConnection has not
+		// been registered yet, so the out channel filling up is not a concern.
 		if cc.delay != 0 {
-			var delayQueue [][]byte
 			t := time.NewTimer(cc.delay)
-			done := false
-			for !done {
-				select {
-				case <-ctx.Done():
-					return
-				case data := <-cc.out:
-					delayQueue = append(delayQueue, data)
-				case <-t.C:
-					for _, data := range delayQueue {
-						err := cc.writeRaw(data)
-						if err != nil {
-							logWarn(err, "error writing data to client")
-							cc.clientManager.Remove(cc)
-							return
-						}
-					}
-					done = true
-				}
+			select {
+			case <-ctx.Done():
+				return
+			case <-t.C:
+			}
+		}
+
+		// Send the current backlog before registering the ClientConnection,
+		// in case the backlog is very large.
+		segment := cc.backlog.Head()
+		if !backlog.IsBacklogSegmentNil(segment) && segment.Start() < uint64(cc.requestedSeqNum) {
+			s, err := cc.backlog.Lookup(uint64(cc.requestedSeqNum))
+			if err != nil {
+				logWarn(err, "error finding requested sequence number in backlog: sending the entire backlog instead")
+			} else {
+				segment = s
 			}
 		}
+		err := cc.writeBacklog(ctx, segment)
+		if errors.Is(err, errContextDone) {
+			return
+		} else if err != nil {
+			logWarn(err, "error writing messages from backlog")
+			cc.Remove()
+			return
+		}
 
+		cc.Register()
+		timer := time.NewTimer(5 * time.Second)
+		select {
+		case <-ctx.Done():
+			return
+		case <-cc.registered:
+			log.Debug("ClientConnection registered with ClientManager", "client", cc.Name)
+		case <-timer.C:
+			log.Error("timed out waiting for ClientConnection to register with ClientManager", "client", cc.Name)
+		}
+
+		// broadcast any new messages sent to the out channel
 		for {
 			select {
 			case <-ctx.Done():
 				return
-			case data := <-cc.out:
-				err := cc.writeRaw(data)
+			case msg := <-cc.out:
+				if msg.sequenceNumber != nil && uint64(*msg.sequenceNumber) <= cc.LastSentSeqNum.Load() {
+					log.Debug("message with this sequence number has already been sent to this client, skipping it", "client", cc.Name, "sequence number", *msg.sequenceNumber)
+					continue
+				}
+
+				expSeqNum := cc.LastSentSeqNum.Load() + 1
+				if !cc.backlogSent && msg.sequenceNumber != nil && uint64(*msg.sequenceNumber) > expSeqNum {
+					catchupSeqNum := uint64(*msg.sequenceNumber) - 1
+					bm, err := cc.backlog.Get(expSeqNum, catchupSeqNum)
+					if err != nil {
+						logWarn(err, fmt.Sprintf("error reading messages %d to %d from backlog", expSeqNum, catchupSeqNum))
+						return
+					}
+
+					err = cc.writeBroadcastMessage(bm)
+					if err != nil {
+						logWarn(err, fmt.Sprintf("error writing messages %d to %d from backlog", expSeqNum, catchupSeqNum))
+						cc.Remove()
+						return
+					}
+				}
+				cc.backlogSent = true
+
+				err := cc.writeRaw(msg.data)
 				if err != nil {
 					logWarn(err, "error writing data to client")
-					cc.clientManager.Remove(cc)
+					cc.Remove()
 					return
 				}
 			}
@@ -119,6 +260,12 @@ func (cc *ClientConnection) Start(parentCtx context.Context) {
 	})
 }
 
+// Registered is used by the ClientManager to indicate that the
+// ClientConnection has been registered with the ClientManager.
+func (cc *ClientConnection) Registered() {
+	cc.registered <- true
+}
+
 func (cc *ClientConnection) StopOnly() {
 	// Ignore errors from conn.Close since we are just shutting down
 	_ = cc.conn.Close()
@@ -161,23 +308,6 @@ func (cc *ClientConnection) readRequest(ctx context.Context, timeout time.Durati
 	return data, opCode, err
 }
 
-func (cc *ClientConnection) Write(x interface{}) error {
-	cc.ioMutex.Lock()
-	defer cc.ioMutex.Unlock()
-
-	notCompressed, compressed, err := serializeMessage(cc.clientManager, x, !cc.compression, cc.compression)
-	if err != nil {
-		return err
-	}
-
-	if cc.compression {
-		cc.out <- compressed.Bytes()
-	} else {
-		cc.out <- notCompressed.Bytes()
-	}
-	return nil
-}
-
 func (cc *ClientConnection) writeRaw(p []byte) error {
 	cc.ioMutex.Lock()
 	defer cc.ioMutex.Unlock()
diff --git a/wsbroadcastserver/clientmanager.go b/wsbroadcastserver/clientmanager.go
index f140e6254f..a88716756a 100644
--- a/wsbroadcastserver/clientmanager.go
+++ b/wsbroadcastserver/clientmanager.go
@@ -10,7 +10,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"net"
 	"strings"
 	"sync/atomic"
 	"time"
@@ -25,27 +24,21 @@ import (
 	"github.com/ethereum/go-ethereum/metrics"
 
 	"github.com/offchainlabs/nitro/arbutil"
+	"github.com/offchainlabs/nitro/broadcaster/backlog"
+	m "github.com/offchainlabs/nitro/broadcaster/message"
 	"github.com/offchainlabs/nitro/util/stopwaiter"
 )
 
 var (
-	clientsCurrentGauge               = metrics.NewRegisteredGauge("arb/feed/clients/current", nil)
-	clientsConnectCount               = metrics.NewRegisteredCounter("arb/feed/clients/connect", nil)
-	clientsDisconnectCount            = metrics.NewRegisteredCounter("arb/feed/clients/disconnect", nil)
-	clientsTotalSuccessCounter        = metrics.NewRegisteredCounter("arb/feed/clients/success", nil)
-	clientsTotalFailedRegisterCounter = metrics.NewRegisteredCounter("arb/feed/clients/failed/register", nil)
-	clientsTotalFailedUpgradeCounter  = metrics.NewRegisteredCounter("arb/feed/clients/failed/upgrade", nil)
-	clientsTotalFailedWorkerCounter   = metrics.NewRegisteredCounter("arb/feed/clients/failed/worker", nil)
-	clientsDurationHistogram          = metrics.NewRegisteredHistogram("arb/feed/clients/duration", nil, metrics.NewBoundedHistogramSample())
+	clientsCurrentGauge              = metrics.NewRegisteredGauge("arb/feed/clients/current", nil)
+	clientsConnectCount              = metrics.NewRegisteredCounter("arb/feed/clients/connect", nil)
+	clientsDisconnectCount           = metrics.NewRegisteredCounter("arb/feed/clients/disconnect", nil)
+	clientsTotalSuccessCounter       = metrics.NewRegisteredCounter("arb/feed/clients/success", nil)
+	clientsTotalFailedUpgradeCounter = metrics.NewRegisteredCounter("arb/feed/clients/failed/upgrade", nil)
+	clientsTotalFailedWorkerCounter  = metrics.NewRegisteredCounter("arb/feed/clients/failed/worker", nil)
+	clientsDurationHistogram         = metrics.NewRegisteredHistogram("arb/feed/clients/duration", nil, metrics.NewBoundedHistogramSample())
 )
 
-// CatchupBuffer is a Protocol-specific client catch-up logic can be injected using this interface
-type CatchupBuffer interface {
-	OnRegisterClient(*ClientConnection) (error, int, time.Duration)
-	OnDoBroadcast(interface{}) error
-	GetMessageCount() int
-}
-
 // ClientManager manages client connections
 type ClientManager struct {
 	stopwaiter.StopWaiter
@@ -54,30 +47,24 @@
 	clientCount   int32
 	pool          *gopool.Pool
 	poller        netpoll.Poller
-	broadcastChan chan interface{}
+	broadcastChan chan *m.BroadcastMessage
 	clientAction  chan ClientConnectionAction
 	config        BroadcasterConfigFetcher
-	catchupBuffer CatchupBuffer
-	flateWriter   *flate.Writer
+	backlog       backlog.Backlog
 
 	connectionLimiter *ConnectionLimiter
 }
 
-type ClientConnectionAction struct {
-	cc     *ClientConnection
-	create bool
-}
-
-func NewClientManager(poller netpoll.Poller, configFetcher BroadcasterConfigFetcher, catchupBuffer CatchupBuffer) *ClientManager {
+func NewClientManager(poller netpoll.Poller, configFetcher BroadcasterConfigFetcher, bklg backlog.Backlog) *ClientManager {
 	config := configFetcher()
 	return &ClientManager{
 		poller:            poller,
 		pool:              gopool.NewPool(config.Workers, config.Queue, 1),
 		clientPtrMap:      make(map[*ClientConnection]bool),
-		broadcastChan:     make(chan interface{}, 1),
+		broadcastChan:     make(chan *m.BroadcastMessage, 1),
 		clientAction:      make(chan ClientConnectionAction, 128),
 		config:            configFetcher,
-		catchupBuffer:     catchupBuffer,
+		backlog:           bklg,
 		connectionLimiter: NewConnectionLimiter(func() *ConnectionLimiterConfig { return &configFetcher().ConnectionLimits }),
 	}
 }
@@ -89,6 +76,8 @@ func (cm *ClientManager) registerClient(ctx context.Context, clientConnection *C
 		}
 	}()
 
+	// TODO(clamb): the clientsTotalFailedRegisterCounter metric was removed when the backlog logic moved into ClientConnection. Ask Lee whether it should be reintroduced or dropped entirely now that the behaviour has changed.
+
 	if cm.config().ConnectionLimits.Enable && !cm.connectionLimiter.Register(clientConnection.clientIp) {
 		return fmt.Errorf("Connection limited %s", clientConnection.clientIp)
 	}
@@ -97,42 +86,12 @@ func (cm *ClientManager) registerClient(ctx context.Context, clientConnection *C
 	clientsConnectCount.Inc(1)
 	atomic.AddInt32(&cm.clientCount, 1)
 
-	err, sent, elapsed := cm.catchupBuffer.OnRegisterClient(clientConnection)
-	if err != nil {
-		clientsTotalFailedRegisterCounter.Inc(1)
-		if cm.config().ConnectionLimits.Enable {
-			cm.connectionLimiter.Release(clientConnection.clientIp)
-		}
-		return err
-	}
-	if cm.config().LogConnect {
-		log.Info("client registered", "client", clientConnection.Name, "requestedSeqNum", clientConnection.RequestedSeqNum(), "sentCount", sent, "elapsed", elapsed)
-	}
-
-	clientConnection.Start(ctx)
 	cm.clientPtrMap[clientConnection] = true
 	clientsTotalSuccessCounter.Inc(1)
 	return nil
 }
 
-// Register registers new connection as a Client.
-func (cm *ClientManager) Register(
-	conn net.Conn,
-	desc *netpoll.Desc,
-	requestedSeqNum arbutil.MessageIndex,
-	connectingIP net.IP,
-	compression bool,
-) *ClientConnection {
-	createClient := ClientConnectionAction{
-		NewClientConnection(conn, desc, cm, requestedSeqNum, connectingIP, compression, cm.config().ClientDelay),
-		true,
-	}
-	cm.clientAction <- createClient
-
-	return createClient.cc
-}
-
 // removeAll removes all clients after main ClientManager thread exits
 func (cm *ClientManager) removeAll() {
 	// Only called after main ClientManager thread exits, so remove client directly
@@ -177,19 +136,12 @@ func (cm *ClientManager) removeClient(clientConnection *ClientConnection) {
 	delete(cm.clientPtrMap, clientConnection)
 }
 
-func (cm *ClientManager) Remove(clientConnection *ClientConnection) {
-	cm.clientAction <- ClientConnectionAction{
-		clientConnection,
-		false,
-	}
-}
-
 func (cm *ClientManager) ClientCount() int32 {
 	return atomic.LoadInt32(&cm.clientCount)
 }
 
 // Broadcast sends batch item to all clients.
-func (cm *ClientManager) Broadcast(bm interface{}) {
+func (cm *ClientManager) Broadcast(bm *m.BroadcastMessage) {
 	if cm.Stopped() {
 		// This should only occur if a reorg occurs after the broadcast server is stopped,
 		// with the sequencer enabled but not the sequencer coordinator.
@@ -199,16 +151,16 @@ func (cm *ClientManager) Broadcast(bm interface{}) {
 	cm.broadcastChan <- bm
 }
 
-func (cm *ClientManager) doBroadcast(bm interface{}) ([]*ClientConnection, error) {
-	if err := cm.catchupBuffer.OnDoBroadcast(bm); err != nil {
+func (cm *ClientManager) doBroadcast(bm *m.BroadcastMessage) ([]*ClientConnection, error) {
+	if err := cm.backlog.Append(bm); err != nil {
 		return nil, err
 	}
 	config := cm.config()
 
 	//                                      /-> wsutil.Writer -> not compressed msg buffer
 	// bm -> json.Encoder -> io.MultiWriter -|
-	//                                      \-> cm.flateWriter -> wsutil.Writer -> compressed msg buffer
+	//                                      \-> flateWriter -> wsutil.Writer -> compressed msg buffer
 
-	notCompressed, compressed, err := serializeMessage(cm, bm, !config.RequireCompression, config.EnableCompression)
+	notCompressed, compressed, err := serializeMessage(bm, !config.RequireCompression, config.EnableCompression)
 	if err != nil {
 		return nil, err
 	}
@@ -234,8 +186,23 @@ func (cm *ClientManager) doBroadcast(bm interface{}) ([]*ClientConnection, error
 				continue
 			}
 		}
+
+		var seqNum *arbutil.MessageIndex
+		n := len(bm.Messages)
+		if n == 0 {
+			seqNum = nil
+		} else if n == 1 {
+			seqNum = &bm.Messages[0].SequenceNumber
+		} else {
+			return nil, fmt.Errorf("doBroadcast was sent %d BroadcastFeedMessages, it can only parse 1 BroadcastFeedMessage at a time", n)
+		}
+
+		msg := message{
+			sequenceNumber: seqNum,
+			data:           data,
+		}
 		select {
-		case client.out <- data:
+		case client.out <- msg:
 		default:
 			// Queue for client too backed up, disconnect instead of blocking on channel send
 			sendQueueTooLargeCount++
@@ -254,7 +221,12 @@ func (cm *ClientManager) doBroadcast(bm interface{}) ([]*ClientConnection, error
 	return clientDeleteList, nil
 }
 
-func serializeMessage(cm *ClientManager, bm interface{}, enableNonCompressedOutput, enableCompressedOutput bool) (bytes.Buffer, bytes.Buffer, error) {
+func serializeMessage(bm *m.BroadcastMessage, enableNonCompressedOutput, enableCompressedOutput bool) (bytes.Buffer, bytes.Buffer, error) {
+	flateWriter, err := flate.NewWriterDict(nil, DeflateCompressionLevel, GetStaticCompressorDictionary())
+	if err != nil {
+		return bytes.Buffer{}, bytes.Buffer{}, fmt.Errorf("unable to create flate writer: %w", err)
+	}
+
 	var notCompressed bytes.Buffer
 	var compressed bytes.Buffer
 	writers := []io.Writer{}
@@ -265,19 +237,12 @@ func serializeMessage(cm *ClientManager, bm interface{}, enableNonCompressedOutp
 		writers = append(writers, notCompressedWriter)
 	}
 	if enableCompressedOutput {
-		if cm.flateWriter == nil {
-			var err error
-			cm.flateWriter, err = flate.NewWriterDict(nil, DeflateCompressionLevel, GetStaticCompressorDictionary())
-			if err != nil {
-				return bytes.Buffer{}, bytes.Buffer{}, fmt.Errorf("unable to create flate writer: %w", err)
-			}
-		}
 		compressedWriter = wsutil.NewWriter(&compressed, ws.StateServerSide|ws.StateExtended, ws.OpText)
 		var msg wsflate.MessageState
 		msg.SetCompressed(true)
 		compressedWriter.SetExtensions(&msg)
-		cm.flateWriter.Reset(compressedWriter)
-		writers = append(writers, cm.flateWriter)
+		flateWriter.Reset(compressedWriter)
+		writers = append(writers, flateWriter)
 	}
 
 	multiWriter := io.MultiWriter(writers...)
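Note on the serializeMessage change above: the shared cm.flateWriter is removed, likely because serializeMessage is now also called from individual ClientConnection goroutines (via writeBroadcastMessage), and a single *flate.Writer cannot safely be reset and reused by concurrent callers. The following is a minimal, self-contained Go sketch of the per-call pattern; the function name, compression level, and dictionary bytes are illustrative stand-ins and not part of this diff:

// compress.go - hypothetical sketch of the per-call flate writer pattern.
package main

import (
	"bytes"
	"compress/flate"
	"fmt"
)

// compressWithDict compresses payload with a preset dictionary, mirroring the
// flate.NewWriterDict call in serializeMessage. A fresh writer is created per
// call, so concurrent callers never share compressor state.
func compressWithDict(payload, dict []byte) ([]byte, error) {
	var buf bytes.Buffer
	fw, err := flate.NewWriterDict(&buf, flate.BestCompression, dict)
	if err != nil {
		return nil, err
	}
	if _, err := fw.Write(payload); err != nil {
		return nil, err
	}
	// Close flushes the final block, as serializeMessage does before
	// flushing the websocket writer.
	if err := fw.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func main() {
	out, err := compressWithDict([]byte(`{"version":1,"messages":[]}`), []byte(`{"version":`))
	if err != nil {
		panic(err)
	}
	fmt.Printf("compressed to %d bytes\n", len(out))
}

The trade-off is one writer allocation per serialized message instead of a reused writer, which buys lock-free concurrency across connections.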
@@ -291,7 +256,7 @@ func serializeMessage(cm *ClientManager, bm interface{}, enableNonCompressedOutp
 		}
 	}
 	if compressedWriter != nil {
-		if err := cm.flateWriter.Close(); err != nil {
+		if err := flateWriter.Close(); err != nil {
 			return bytes.Buffer{}, bytes.Buffer{}, fmt.Errorf("unable to close flate writer: %w", err)
 		}
 		if err := compressedWriter.Flush(); err != nil {
@@ -348,13 +313,31 @@ func (cm *ClientManager) Start(parentCtx context.Context) {
 					// Log message already output in registerClient
 					cm.removeClientImpl(clientAction.cc)
 				}
+				clientAction.cc.Registered()
 			} else {
 				cm.removeClient(clientAction.cc)
 			}
 		case bm := <-cm.broadcastChan:
 			var err error
-			clientDeleteList, err = cm.doBroadcast(bm)
-			logError(err, "failed to do broadcast")
+			for i, msg := range bm.Messages {
+				singleMsg := &m.BroadcastMessage{
+					Version:  bm.Version,
+					Messages: []*m.BroadcastFeedMessage{msg},
+				}
+				// This ensures that only one message is sent with the confirmed sequence number
+				if i == 0 {
+					singleMsg.ConfirmedSequenceNumberMessage = bm.ConfirmedSequenceNumberMessage
+				}
+				clientDeleteList, err = cm.doBroadcast(singleMsg)
+				logError(err, "failed to do broadcast")
+			}
+
+			// A BroadcastMessage carrying a ConfirmedSequenceNumberMessage may
+			// contain no feed messages; this ensures it is still sent.
+			if len(bm.Messages) == 0 {
+				clientDeleteList, err = cm.doBroadcast(bm)
+				logError(err, "failed to do broadcast")
+			}
 		case <-pingTimer.C:
 			clientDeleteList = cm.verifyClients()
 			pingTimer.Reset(cm.config().Ping)
diff --git a/wsbroadcastserver/wsbroadcastserver.go b/wsbroadcastserver/wsbroadcastserver.go
index d51b368400..eb47f8a635 100644
--- a/wsbroadcastserver/wsbroadcastserver.go
+++ b/wsbroadcastserver/wsbroadcastserver.go
@@ -25,6 +25,8 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/offchainlabs/nitro/arbutil"
+	"github.com/offchainlabs/nitro/broadcaster/backlog"
+	m "github.com/offchainlabs/nitro/broadcaster/message"
 )
 
 var (
@@ -66,6 +68,7 @@ type BroadcasterConfig struct {
 	MaxCatchup       int                     `koanf:"max-catchup" reload:"hot"`
 	ConnectionLimits ConnectionLimiterConfig `koanf:"connection-limits" reload:"hot"`
 	ClientDelay      time.Duration           `koanf:"client-delay" reload:"hot"`
+	Backlog          backlog.Config          `koanf:"backlog" reload:"hot"`
 }
 
 func (bc *BroadcasterConfig) Validate() error {
@@ -100,6 +103,7 @@ func BroadcasterConfigAddOptions(prefix string, f *flag.FlagSet) {
 	f.Int(prefix+".max-catchup", DefaultBroadcasterConfig.MaxCatchup, "the maximum size of the catchup buffer (-1 means unlimited)")
 	ConnectionLimiterConfigAddOptions(prefix+".connection-limits", f)
 	f.Duration(prefix+".client-delay", DefaultBroadcasterConfig.ClientDelay, "delay the first messages sent to each client by this amount")
+	backlog.AddOptions(prefix+".backlog", f)
 }
 
 var DefaultBroadcasterConfig = BroadcasterConfig{
@@ -125,6 +129,7 @@ var DefaultBroadcasterConfig = BroadcasterConfig{
 	MaxCatchup:       -1,
 	ConnectionLimits: DefaultConnectionLimiterConfig,
 	ClientDelay:      0,
+	Backlog:          backlog.DefaultConfig,
 }
 
 var DefaultTestBroadcasterConfig = BroadcasterConfig{
@@ -150,6 +155,7 @@ var DefaultTestBroadcasterConfig = BroadcasterConfig{
 	MaxCatchup:       -1,
 	ConnectionLimits: DefaultConnectionLimiterConfig,
 	ClientDelay:      0,
+	Backlog:          backlog.DefaultTestConfig,
 }
 
 type WSBroadcastServer struct {
@@ -163,18 +169,18 @@ type WSBroadcastServer struct {
 	config        BroadcasterConfigFetcher
 	started       bool
 	clientManager *ClientManager
-	catchupBuffer CatchupBuffer
+	backlog       backlog.Backlog
 	chainId       uint64
 	fatalErrChan  chan error
 }
 
-func NewWSBroadcastServer(config BroadcasterConfigFetcher, catchupBuffer CatchupBuffer, chainId uint64, fatalErrChan chan error) *WSBroadcastServer {
+func NewWSBroadcastServer(config BroadcasterConfigFetcher, bklg backlog.Backlog, chainId uint64, fatalErrChan chan error) *WSBroadcastServer {
 	return &WSBroadcastServer{
-		config:        config,
-		started:       false,
-		catchupBuffer: catchupBuffer,
-		chainId:       chainId,
-		fatalErrChan:  fatalErrChan,
+		config:       config,
+		started:      false,
+		backlog:      bklg,
+		chainId:      chainId,
+		fatalErrChan: fatalErrChan,
 	}
 }
 
@@ -192,7 +198,7 @@ func (s *WSBroadcastServer) Initialize() error {
 
 	// Make pool of X size, Y sized work queue and one pre-spawned
 	// goroutine.
-	s.clientManager = NewClientManager(s.poller, s.config, s.catchupBuffer)
+	s.clientManager = NewClientManager(s.poller, s.config, s.backlog)
 
 	return nil
 }
@@ -372,7 +378,8 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands
 	// Register incoming client in clientManager.
 	safeConn := writeDeadliner{conn, config.WriteTimeout}
-	client := s.clientManager.Register(safeConn, desc, requestedSeqNum, connectingIP, compressionAccepted)
+	client := NewClientConnection(safeConn, desc, s.clientManager.clientAction, requestedSeqNum, connectingIP, compressionAccepted, s.config().MaxSendQueue, s.config().ClientDelay, s.backlog)
+	client.Start(ctx)
 
 	// Subscribe to events about conn.
 	err = s.poller.Start(desc, func(ev netpoll.Event) {
@@ -380,7 +387,7 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands
 			// ReadHup or Hup received, means the client has closed the connection,
 			// remove it from the clientManager registry.
 			log.Debug("Hup received", "age", client.Age(), "client", client.Name)
-			s.clientManager.Remove(client)
+			client.Remove()
 			return
 		}
 
@@ -392,7 +399,7 @@ func (s *WSBroadcastServer) StartWithHeader(ctx context.Context, header ws.Hands
 		s.clientManager.pool.Schedule(func() {
 			// Ignore any messages sent from client, close on any error
 			if _, _, err := client.Receive(ctx, s.config().ReadTimeout); err != nil {
-				s.clientManager.Remove(client)
+				client.Remove()
 				return
 			}
 		})
@@ -528,7 +535,7 @@ func (s *WSBroadcastServer) Started() bool {
 }
 
 // Broadcast sends batch item to all clients.
-func (s *WSBroadcastServer) Broadcast(bm interface{}) {
+func (s *WSBroadcastServer) Broadcast(bm *m.BroadcastMessage) {
 	s.clientManager.Broadcast(bm)
 }
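Note on the registration flow in this diff: ClientConnection no longer holds a *ClientManager pointer; both registration and removal travel as ClientConnectionAction values over the manager's clientAction channel, which ClientManager.Start drains in its select loop. Below is a minimal, self-contained Go sketch of that channel-based ownership pattern; all names are invented for illustration and are not part of the diff:

// actions.go - hypothetical sketch of the clientAction channel pattern.
package main

import "fmt"

type conn struct{ name string }

// action mirrors ClientConnectionAction: the connection plus a create flag.
type action struct {
	c      *conn
	create bool
}

// manager owns the client map exclusively; no other goroutine touches it,
// so no locking is needed.
func manager(actions <-chan action, done chan<- struct{}) {
	clients := make(map[*conn]bool)
	for a := range actions {
		if a.create {
			clients[a.c] = true
			fmt.Println("registered", a.c.name, "total", len(clients))
		} else {
			delete(clients, a.c)
			fmt.Println("removed", a.c.name, "total", len(clients))
		}
	}
	done <- struct{}{}
}

func main() {
	actions := make(chan action, 128) // the diff also buffers clientAction at 128
	done := make(chan struct{})
	go manager(actions, done)

	c := &conn{name: "client-1"}
	actions <- action{c: c, create: true}  // analogous to ClientConnection.Register
	actions <- action{c: c, create: false} // analogous to ClientConnection.Remove
	close(actions)
	<-done
}

Confining the map to the single goroutine that drains the channel is what lets the connection send its own (possibly very large) backlog first and only then register itself, without the manager blocking or sharing state with it.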