diff --git a/pkg/protocol/engine/congestioncontrol/scheduler/drr/basicbuffer.go b/pkg/protocol/engine/congestioncontrol/scheduler/drr/basicbuffer.go index fabd1e72c..f90a4458b 100644 --- a/pkg/protocol/engine/congestioncontrol/scheduler/drr/basicbuffer.go +++ b/pkg/protocol/engine/congestioncontrol/scheduler/drr/basicbuffer.go @@ -2,6 +2,7 @@ package drr import ( "container/ring" + "fmt" "math" "time" @@ -20,8 +21,9 @@ import ( type BasicBuffer struct { activeIssuers *shrinkingmap.ShrinkingMap[iotago.AccountID, *ring.Ring] ring *ring.Ring - // size is the number of blocks in the buffer. - size atomic.Int64 + + readyBlocksCount atomic.Int64 + totalBlocksCount atomic.Int64 tokenBucket float64 lastScheduleTime time.Time @@ -57,11 +59,6 @@ func (b *BasicBuffer) Clear() { }) } -// Size returns the total number of blocks in BasicBuffer. -func (b *BasicBuffer) Size() int { - return int(b.size.Load()) -} - // IssuerQueue returns the queue for the corresponding issuer. func (b *BasicBuffer) IssuerQueue(issuerID iotago.AccountID) *IssuerQueue { element, exists := b.activeIssuers.Get(issuerID) @@ -97,8 +94,25 @@ func (b *BasicBuffer) IssuerQueueBlockCount(issuerID iotago.AccountID) int { } func (b *BasicBuffer) CreateIssuerQueue(issuerID iotago.AccountID) *IssuerQueue { - issuerQueue := NewIssuerQueue(issuerID) - b.activeIssuers.Set(issuerID, b.ringInsert(issuerQueue)) + element := b.activeIssuers.Compute(issuerID, func(_ *ring.Ring, exists bool) *ring.Ring { + if exists { + panic(fmt.Sprintf("issuer queue already exists: %s", issuerID.String())) + } + + return b.ringInsert(NewIssuerQueue(issuerID, func(totalSizeDelta int64, readySizeDelta int64) { + if totalSizeDelta != 0 { + b.totalBlocksCount.Add(totalSizeDelta) + } + if readySizeDelta != 0 { + b.readyBlocksCount.Add(readySizeDelta) + } + })) + }) + + issuerQueue, isIQ := element.Value.(*IssuerQueue) + if !isIQ { + panic("buffer contains elements that are not issuer queues") + } return issuerQueue } @@ -127,7 +141,7 @@ func (b *BasicBuffer) RemoveIssuerQueue(issuerID iotago.AccountID) { if !isIQ { panic("buffer contains elements that are not issuer queues") } - b.size.Sub(int64(issuerQueue.Size())) + issuerQueue.Clear() b.ringRemove(element) b.activeIssuers.Delete(issuerID) @@ -158,10 +172,8 @@ func (b *BasicBuffer) Submit(blk *blocks.Block, issuerQueue *IssuerQueue, quantu return nil, false } - b.size.Inc() - // if max buffer size exceeded, drop from tail of the longest mana-scaled queue - if b.Size() > maxBuffer { + if b.TotalBlocksCount() > maxBuffer { return b.dropTail(quantumFunc, maxBuffer), true } @@ -178,40 +190,14 @@ func (b *BasicBuffer) Ready(block *blocks.Block) bool { return issuerQueue.Ready(block) } -// ReadyBlocksCount returns the number of ready blocks in the buffer. -func (b *BasicBuffer) ReadyBlocksCount() (readyBlocksCount int) { - start := b.Current() - if start == nil { - return - } - - for q := start; ; { - readyBlocksCount += q.readyHeap.Len() - q = b.Next() - if q == start { - break - } - } - - return -} - // TotalBlocksCount returns the number of blocks in the buffer. func (b *BasicBuffer) TotalBlocksCount() (blocksCount int) { - start := b.Current() - if start == nil { - return - } - for q := start; ; { - blocksCount += q.readyHeap.Len() - blocksCount += q.nonReadyMap.Size() - q = b.Next() - if q == start { - break - } - } + return int(b.totalBlocksCount.Load()) +} - return +// ReadyBlocksCount returns the number of ready blocks in the buffer. +func (b *BasicBuffer) ReadyBlocksCount() (readyBlocksCount int) { + return int(b.readyBlocksCount.Load()) } // Next returns the next IssuerQueue in round-robin order. @@ -250,8 +236,6 @@ func (b *BasicBuffer) PopFront() *blocks.Block { return nil } - b.size.Dec() - return block } @@ -275,7 +259,7 @@ func (b *BasicBuffer) IssuerIDs() []iotago.AccountID { func (b *BasicBuffer) dropTail(quantumFunc func(iotago.AccountID) Deficit, maxBuffer int) (droppedBlocks []*blocks.Block) { // remove as many blocks as necessary to stay within max buffer size - for b.Size() > maxBuffer { + for b.TotalBlocksCount() > maxBuffer { // find the longest mana-scaled queue maxIssuerID := b.mustLongestQueueIssuerID(quantumFunc) longestQueue := b.IssuerQueue(maxIssuerID) @@ -288,7 +272,6 @@ func (b *BasicBuffer) dropTail(quantumFunc func(iotago.AccountID) Deficit, maxBu panic("buffer is full, but tail of longest queue does not exist") } - b.size.Dec() droppedBlocks = append(droppedBlocks, tail) } diff --git a/pkg/protocol/engine/congestioncontrol/scheduler/drr/issuerqueue.go b/pkg/protocol/engine/congestioncontrol/scheduler/drr/issuerqueue.go index 68113309b..7cde0038c 100644 --- a/pkg/protocol/engine/congestioncontrol/scheduler/drr/issuerqueue.go +++ b/pkg/protocol/engine/congestioncontrol/scheduler/drr/issuerqueue.go @@ -18,19 +18,37 @@ import ( // IssuerQueue keeps the submitted blocks of an issuer. type IssuerQueue struct { - issuerID iotago.AccountID + issuerID iotago.AccountID + sizeChangedFunc func(totalSizeDelta int64, readySizeDelta int64, workDelta int64) + nonReadyMap *shrinkingmap.ShrinkingMap[iotago.BlockID, *blocks.Block] readyHeap generalheap.Heap[timed.HeapKey, *blocks.Block] - size atomic.Int64 - work atomic.Int64 + + size atomic.Int64 + work atomic.Int64 } // NewIssuerQueue returns a new IssuerQueue. -func NewIssuerQueue(issuerID iotago.AccountID) *IssuerQueue { - return &IssuerQueue{ +func NewIssuerQueue(issuerID iotago.AccountID, sizeChangedCallback func(totalSizeDelta int64, readySizeDelta int64)) *IssuerQueue { + queue := &IssuerQueue{ issuerID: issuerID, nonReadyMap: shrinkingmap.New[iotago.BlockID, *blocks.Block](), } + + queue.sizeChangedFunc = func(totalSizeDelta int64, readySizeDelta int64, workDelta int64) { + if totalSizeDelta != 0 { + queue.size.Add(totalSizeDelta) + } + if workDelta != 0 { + queue.work.Add(workDelta) + } + + if sizeChangedCallback != nil { + sizeChangedCallback(totalSizeDelta, readySizeDelta) + } + } + + return queue } // Size returns the total number of blocks in the queue. @@ -70,21 +88,19 @@ func (q *IssuerQueue) Submit(element *blocks.Block) bool { } q.nonReadyMap.Set(element.ID(), element) - q.size.Inc() - q.work.Add(int64(element.WorkScore())) + q.sizeChangedFunc(1, 0, int64(element.WorkScore())) return true } -// Unsubmit removes a previously submitted block from the queue. -func (q *IssuerQueue) Unsubmit(block *blocks.Block) bool { +// unsubmit removes a previously submitted block from the queue. +func (q *IssuerQueue) unsubmit(block *blocks.Block) bool { if _, submitted := q.nonReadyMap.Get(block.ID()); !submitted { return false } q.nonReadyMap.Delete(block.ID()) - q.size.Dec() - q.work.Sub(int64(block.WorkScore())) + q.sizeChangedFunc(-1, 0, -int64(block.WorkScore())) return true } @@ -98,6 +114,8 @@ func (q *IssuerQueue) Ready(block *blocks.Block) bool { q.nonReadyMap.Delete(block.ID()) heap.Push(&q.readyHeap, &generalheap.HeapElement[timed.HeapKey, *blocks.Block]{Value: block, Key: timed.HeapKey(block.IssuingTime())}) + q.sizeChangedFunc(0, 1, 0) + return true } @@ -112,6 +130,18 @@ func (q *IssuerQueue) IDs() (ids []iotago.BlockID) { return ids } +// Clear removes all blocks from the queue. +func (q *IssuerQueue) Clear() { + readyBlocksCount := int64(q.readyHeap.Len()) + + q.nonReadyMap.Clear() + for q.readyHeap.Len() > 0 { + _ = q.readyHeap.Pop() + } + + q.sizeChangedFunc(-int64(q.Size()), -readyBlocksCount, -int64(q.Work())) +} + // Front returns the first ready block in the queue. func (q *IssuerQueue) Front() *blocks.Block { if q == nil || q.readyHeap.Len() == 0 { @@ -132,8 +162,7 @@ func (q *IssuerQueue) PopFront() *blocks.Block { panic("unable to pop from a non-empty heap.") } blk := heapElement.Value - q.size.Dec() - q.work.Sub(int64(blk.WorkScore())) + q.sizeChangedFunc(-1, -1, -int64(blk.WorkScore())) return blk } @@ -152,7 +181,7 @@ func (q *IssuerQueue) RemoveTail() *blocks.Block { heapTailIndex := q.heapTail() // if heap tail (oldest ready block) does not exist or is newer than oldest non-ready block, unsubmit the oldest non-ready block if oldestNonReadyBlock != nil && (heapTailIndex < 0 || q.readyHeap[heapTailIndex].Key.CompareTo(timed.HeapKey(oldestNonReadyBlock.IssuingTime())) > 0) { - if q.Unsubmit(oldestNonReadyBlock) { + if q.unsubmit(oldestNonReadyBlock) { return oldestNonReadyBlock } } else if heapTailIndex < 0 { // the heap is empty @@ -166,8 +195,7 @@ func (q *IssuerQueue) RemoveTail() *blocks.Block { panic("trying to remove a heap element that does not exist.") } blk := heapElement.Value - q.size.Dec() - q.work.Sub(int64(blk.WorkScore())) + q.sizeChangedFunc(-1, -1, -int64(blk.WorkScore())) return blk } diff --git a/pkg/protocol/engine/congestioncontrol/scheduler/drr/scheduler.go b/pkg/protocol/engine/congestioncontrol/scheduler/drr/scheduler.go index 4239d831a..d9dcf8412 100644 --- a/pkg/protocol/engine/congestioncontrol/scheduler/drr/scheduler.go +++ b/pkg/protocol/engine/congestioncontrol/scheduler/drr/scheduler.go @@ -1,6 +1,7 @@ package drr import ( + "fmt" "math" "sync" "time" @@ -78,15 +79,9 @@ func NewProvider(opts ...options.Option[Scheduler]) module.Provider[*engine.Engi return } - s.validatorBuffer.buffer.ForEach(func(accountID iotago.AccountID, validatorQueue *ValidatorQueue) bool { - if !committee.HasAccount(accountID) { - s.shutdownValidatorQueue(validatorQueue) - } - - return true + s.validatorBuffer.Delete(func(validatorQueue *ValidatorQueue) bool { + return !committee.HasAccount(validatorQueue.AccountID()) }) - - s.validatorBuffer.Clear() } }) e.Ledger.InitializedEvent().OnTrigger(func() { @@ -150,12 +145,6 @@ func (s *Scheduler) shutdown() { s.bufferMutex.Lock() defer s.bufferMutex.Unlock() - // validator workers need to be shut down first, otherwise they will hang on the shutdown channel. - s.validatorBuffer.buffer.ForEach(func(_ iotago.AccountID, validatorQueue *ValidatorQueue) bool { - s.shutdownValidatorQueue(validatorQueue) - - return true - }) s.validatorBuffer.Clear() close(s.shutdownSignal) @@ -168,6 +157,7 @@ func (s *Scheduler) shutdown() { // Start starts the scheduler. func (s *Scheduler) Start() { s.shutdownSignal = make(chan struct{}, 1) + s.workersWg.Add(1) go s.basicBlockLoop() @@ -196,7 +186,7 @@ func (s *Scheduler) ValidatorQueueBlockCount(issuerID iotago.AccountID) int { // BasicBufferSize returns the current buffer size of the Scheduler as block count. func (s *Scheduler) BasicBufferSize() int { - return s.basicBuffer.Size() + return s.basicBuffer.TotalBlocksCount() } func (s *Scheduler) ValidatorBufferSize() int { @@ -221,7 +211,7 @@ func (s *Scheduler) IsBlockIssuerReady(accountID iotago.AccountID, workScores .. defer s.bufferMutex.RUnlock() // if the buffer is completely empty, any issuer can issue a block. - if s.basicBuffer.Size() == 0 { + if s.basicBuffer.TotalBlocksCount() == 0 { return true } @@ -256,13 +246,6 @@ func (s *Scheduler) Reset() { s.bufferMutex.Lock() defer s.bufferMutex.Unlock() - // Validator workers need to be signaled to exit. - s.validatorBuffer.buffer.ForEach(func(_ iotago.AccountID, validatorQueue *ValidatorQueue) bool { - s.shutdownValidatorQueue(validatorQueue) - - return true - }) - s.basicBuffer.Clear() s.validatorBuffer.Clear() } @@ -309,11 +292,7 @@ func (s *Scheduler) enqueueValidationBlock(block *blocks.Block) { s.bufferMutex.Lock() defer s.bufferMutex.Unlock() - _, exists := s.validatorBuffer.Get(block.IssuerID()) - if !exists { - s.addValidator(block.IssuerID()) - } - droppedBlock, submitted := s.validatorBuffer.Submit(block, int(s.apiProvider.CommittedAPI().ProtocolParameters().CongestionControlParameters().MaxValidationBufferSize)) + droppedBlock, submitted := s.getOrCreateValidatorQueue(block.IssuerID()).Submit(block, int(s.apiProvider.CommittedAPI().ProtocolParameters().CongestionControlParameters().MaxValidationBufferSize)) if !submitted { return } @@ -408,29 +387,12 @@ func (s *Scheduler) selectBlockToScheduleWithLocking() { s.bufferMutex.Lock() defer s.bufferMutex.Unlock() - s.validatorBuffer.buffer.ForEach(func(_ iotago.AccountID, validatorQueue *ValidatorQueue) bool { - if s.selectValidationBlockWithoutLocking(validatorQueue) { - s.validatorBuffer.size.Dec() - } - + s.validatorBuffer.ForEachValidatorQueue(func(_ iotago.AccountID, validatorQueue *ValidatorQueue) bool { + validatorQueue.ScheduleNext() return true }) - s.selectBasicBlockWithoutLocking() -} - -func (s *Scheduler) selectValidationBlockWithoutLocking(validatorQueue *ValidatorQueue) bool { - // already a block selected to be scheduled. - if len(validatorQueue.blockChan) > 0 { - return false - } - - if blockToSchedule := validatorQueue.PopFront(); blockToSchedule != nil { - validatorQueue.blockChan <- blockToSchedule - - return true - } - return false + s.selectBasicBlockWithoutLocking() } func (s *Scheduler) selectBasicBlockWithoutLocking() { @@ -606,7 +568,14 @@ func (s *Scheduler) getOrCreateIssuer(accountID iotago.AccountID) *IssuerQueue { func (s *Scheduler) createIssuer(accountID iotago.AccountID) *IssuerQueue { issuerQueue := s.basicBuffer.CreateIssuerQueue(accountID) - s.deficits.Set(accountID, 0) + s.deficits.Compute(accountID, func(_ Deficit, exists bool) Deficit { + if exists { + panic(fmt.Sprintf("issuer already exists: %s", accountID.String())) + } + + // if the issuer is new, we need to set the deficit to 0. + return 0 + }) return issuerQueue } @@ -680,24 +649,14 @@ func (s *Scheduler) isReady(block *blocks.Block) bool { // tryReady tries to set the given block as ready. func (s *Scheduler) tryReady(block *blocks.Block) { if s.isReady(block) { - s.ready(block) + s.basicBuffer.Ready(block) } } // tryReadyValidator tries to set the given validation block as ready. func (s *Scheduler) tryReadyValidationBlock(block *blocks.Block) { if s.isReady(block) { - s.readyValidationBlock(block) - } -} - -func (s *Scheduler) ready(block *blocks.Block) { - s.basicBuffer.Ready(block) -} - -func (s *Scheduler) readyValidationBlock(block *blocks.Block) { - if validatorQueue, exists := s.validatorBuffer.Get(block.IssuerID()); exists { - validatorQueue.Ready(block) + s.validatorBuffer.Ready(block) } } @@ -737,15 +696,11 @@ func (s *Scheduler) deficitFromWork(work iotago.WorkScore) Deficit { return Deficit(work) * deficitScaleFactor } -func (s *Scheduler) addValidator(accountID iotago.AccountID) *ValidatorQueue { - validatorQueue := NewValidatorQueue(accountID) - s.validatorBuffer.Set(accountID, validatorQueue) - s.workersWg.Add(1) - go s.validatorLoop(validatorQueue) +func (s *Scheduler) getOrCreateValidatorQueue(accountID iotago.AccountID) *ValidatorQueue { + validatorQueue := s.validatorBuffer.GetOrCreate(accountID, func(queue *ValidatorQueue) { + s.workersWg.Add(1) + go s.validatorLoop(queue) + }) return validatorQueue } - -func (s *Scheduler) shutdownValidatorQueue(validatorQueue *ValidatorQueue) { - close(validatorQueue.shutdownSignal) -} diff --git a/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorbuffer.go b/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorbuffer.go new file mode 100644 index 000000000..8b68a2176 --- /dev/null +++ b/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorbuffer.go @@ -0,0 +1,83 @@ +package drr + +import ( + "go.uber.org/atomic" + + "github.com/iotaledger/hive.go/ds/shrinkingmap" + "github.com/iotaledger/iota-core/pkg/protocol/engine/blocks" + iotago "github.com/iotaledger/iota.go/v4" +) + +type ValidatorBuffer struct { + buffer *shrinkingmap.ShrinkingMap[iotago.AccountID, *ValidatorQueue] + size atomic.Int64 +} + +func NewValidatorBuffer() *ValidatorBuffer { + return &ValidatorBuffer{ + buffer: shrinkingmap.New[iotago.AccountID, *ValidatorQueue](), + } +} + +func (b *ValidatorBuffer) Size() int { + if b == nil { + return 0 + } + + return int(b.size.Load()) +} + +func (b *ValidatorBuffer) Get(accountID iotago.AccountID) (*ValidatorQueue, bool) { + return b.buffer.Get(accountID) +} + +func (b *ValidatorBuffer) GetOrCreate(accountID iotago.AccountID, onCreateCallback func(*ValidatorQueue)) *ValidatorQueue { + return b.buffer.Compute(accountID, func(currentValue *ValidatorQueue, exists bool) *ValidatorQueue { + if exists { + return currentValue + } + + queue := NewValidatorQueue(accountID, func(totalSizeDelta int64) { + b.size.Add(totalSizeDelta) + }) + if onCreateCallback != nil { + onCreateCallback(queue) + } + + return queue + }) +} + +// Ready marks a previously submitted block as ready to be scheduled. +func (b *ValidatorBuffer) Ready(block *blocks.Block) { + if validatorQueue, exists := b.Get(block.IssuerID()); exists { + validatorQueue.Ready(block) + } +} + +// ForEachValidatorQueue iterates over all validator queues. +func (b *ValidatorBuffer) ForEachValidatorQueue(consumer func(accountID iotago.AccountID, validatorQueue *ValidatorQueue) bool) { + b.buffer.ForEach(func(accountID iotago.AccountID, validatorQueue *ValidatorQueue) bool { + return consumer(accountID, validatorQueue) + }) +} + +// Delete removes all validator queues that match the predicate. +func (b *ValidatorBuffer) Delete(predicate func(element *ValidatorQueue) bool) { + b.buffer.ForEach(func(accountID iotago.AccountID, validatorQueue *ValidatorQueue) bool { + if predicate(validatorQueue) { + // validator workers need to be shut down first, otherwise they will hang on the shutdown channel. + validatorQueue.Shutdown() + b.buffer.Delete(accountID) + } + + return true + }) +} + +func (b *ValidatorBuffer) Clear() { + b.Delete(func(_ *ValidatorQueue) bool { + // remove all + return true + }) +} diff --git a/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorqueue.go b/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorqueue.go index 15f08a5bc..39aea226b 100644 --- a/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorqueue.go +++ b/pkg/protocol/engine/congestioncontrol/scheduler/drr/validatorqueue.go @@ -17,7 +17,9 @@ import ( ) type ValidatorQueue struct { - accountID iotago.AccountID + accountID iotago.AccountID + sizeChangedFunc func(totalSizeDelta int64) + submitted *shrinkingmap.ShrinkingMap[iotago.BlockID, *blocks.Block] inbox generalheap.Heap[timed.HeapKey, *blocks.Block] size atomic.Int64 @@ -29,8 +31,8 @@ type ValidatorQueue struct { shutdownSignal chan struct{} } -func NewValidatorQueue(accountID iotago.AccountID) *ValidatorQueue { - return &ValidatorQueue{ +func NewValidatorQueue(accountID iotago.AccountID, sizeChangedCallback func(totalSizeDelta int64)) *ValidatorQueue { + queue := &ValidatorQueue{ accountID: accountID, submitted: shrinkingmap.New[iotago.BlockID, *blocks.Block](), blockChan: make(chan *blocks.Block, 1), @@ -38,6 +40,15 @@ func NewValidatorQueue(accountID iotago.AccountID) *ValidatorQueue { tokenBucket: 1, lastScheduleTime: time.Now(), } + queue.sizeChangedFunc = func(totalSizeDelta int64) { + queue.size.Add(totalSizeDelta) + + if sizeChangedCallback != nil { + sizeChangedCallback(totalSizeDelta) + } + } + + return queue } func (q *ValidatorQueue) Size() int { @@ -52,6 +63,22 @@ func (q *ValidatorQueue) AccountID() iotago.AccountID { return q.accountID } +// ScheduleNext schedules the next block. +func (q *ValidatorQueue) ScheduleNext() bool { + // already a block selected to be scheduled. + if len(q.blockChan) > 0 { + return false + } + + if blockToSchedule := q.PopFront(); blockToSchedule != nil { + q.blockChan <- blockToSchedule + + return true + } + + return false +} + func (q *ValidatorQueue) Submit(block *blocks.Block, maxBuffer int) (*blocks.Block, bool) { if blkAccountID := block.IssuerID(); q.accountID != blkAccountID { panic(fmt.Sprintf("issuerqueue: queue issuer ID(%x) and issuer ID(%x) does not match.", q.accountID, blkAccountID)) @@ -62,7 +89,7 @@ func (q *ValidatorQueue) Submit(block *blocks.Block, maxBuffer int) (*blocks.Blo } q.submitted.Set(block.ID(), block) - q.size.Inc() + q.sizeChangedFunc(1) if int(q.size.Load()) > maxBuffer { return q.RemoveTail(), true @@ -71,13 +98,13 @@ func (q *ValidatorQueue) Submit(block *blocks.Block, maxBuffer int) (*blocks.Blo return nil, true } -func (q *ValidatorQueue) Unsubmit(block *blocks.Block) bool { +func (q *ValidatorQueue) unsubmit(block *blocks.Block) bool { if _, submitted := q.submitted.Get(block.ID()); !submitted { return false } q.submitted.Delete(block.ID()) - q.size.Dec() + q.sizeChangedFunc(-1) return true } @@ -104,7 +131,7 @@ func (q *ValidatorQueue) PopFront() *blocks.Block { return nil } blk := heapElement.Value - q.size.Dec() + q.sizeChangedFunc(-1) return blk } @@ -122,7 +149,7 @@ func (q *ValidatorQueue) RemoveTail() *blocks.Block { tail := q.tail() // if heap tail does not exist or tail is newer than oldest submitted block, unsubmit oldest block if oldestSubmittedBlock != nil && (tail < 0 || q.inbox[tail].Key.CompareTo(timed.HeapKey(oldestSubmittedBlock.IssuingTime())) > 0) { - q.Unsubmit(oldestSubmittedBlock) + q.unsubmit(oldestSubmittedBlock) return oldestSubmittedBlock } else if tail < 0 { @@ -136,7 +163,7 @@ func (q *ValidatorQueue) RemoveTail() *blocks.Block { return nil } blk := heapElement.Value - q.size.Dec() + q.sizeChangedFunc(-1) return blk } @@ -174,63 +201,19 @@ func (q *ValidatorQueue) deductTokens(tokens float64) { q.tokenBucket -= tokens } -type ValidatorBuffer struct { - buffer *shrinkingmap.ShrinkingMap[iotago.AccountID, *ValidatorQueue] - size atomic.Int64 -} - -func NewValidatorBuffer() *ValidatorBuffer { - return &ValidatorBuffer{ - buffer: shrinkingmap.New[iotago.AccountID, *ValidatorQueue](), +// Clear removes all blocks from the queue. +func (q *ValidatorQueue) Clear() { + q.submitted.Clear() + for q.inbox.Len() > 0 { + _ = heap.Pop(&q.inbox) } -} -func (b *ValidatorBuffer) Size() int { - if b == nil { - return 0 - } - - return int(b.size.Load()) -} - -func (b *ValidatorBuffer) Get(accountID iotago.AccountID) (*ValidatorQueue, bool) { - return b.buffer.Get(accountID) -} - -func (b *ValidatorBuffer) Set(accountID iotago.AccountID, validatorQueue *ValidatorQueue) bool { - return b.buffer.Set(accountID, validatorQueue) -} - -func (b *ValidatorBuffer) Submit(block *blocks.Block, maxBuffer int) (*blocks.Block, bool) { - validatorQueue, exists := b.buffer.Get(block.IssuerID()) - if !exists { - return nil, false - } - droppedBlock, submitted := validatorQueue.Submit(block, maxBuffer) - if submitted { - b.size.Inc() - } - if droppedBlock != nil { - b.size.Dec() - } - - return droppedBlock, submitted + q.sizeChangedFunc(-int64(q.Size())) } -func (b *ValidatorBuffer) Delete(accountID iotago.AccountID) { - validatorQueue, exists := b.buffer.Get(accountID) - if !exists { - return - } - b.size.Sub(int64(validatorQueue.Size())) +// Shutdown stops the queue and clears all blocks. +func (q ValidatorQueue) Shutdown() { + close(q.shutdownSignal) - b.buffer.Delete(accountID) -} - -func (b *ValidatorBuffer) Clear() { - b.buffer.ForEachKey(func(accountID iotago.AccountID) bool { - b.Delete(accountID) - - return true - }) + q.Clear() }