diff --git a/batcher.go b/batcher.go index f045605..6960bb4 100644 --- a/batcher.go +++ b/batcher.go @@ -92,6 +92,8 @@ func (c *ChanTask[R]) Resolve(ret R, err error) { type Batcher[T BatchableTask[R], R any] interface { // Batch submits a BatchableTask to the batcher. Batch(task T) + // Flush executes tasks currently waiting in queue immediately. + Flush() // Close should stop Batch from being called and clean up any background resources. Close() } @@ -108,6 +110,7 @@ type ChanBatcher[T BatchableTask[R], R any] struct { batchCfg BatchCfg batchFn BatchFn[T] taskCh chan T + flushCh chan struct{} closed atomic.Bool } @@ -117,6 +120,7 @@ func NewChanBatcher[T BatchableTask[R], R any](batchCfg BatchCfg, batchFn BatchF batchCfg: batchCfg, batchFn: batchFn, taskCh: make(chan T, 16*batchCnt), + flushCh: make(chan struct{}, 1), } go chanBatcher.worker() return chanBatcher @@ -131,6 +135,16 @@ func (b *ChanBatcher[T, R]) Batch(task T) { } } +// Flush executes tasks currently waiting in queue immediately. +func (b *ChanBatcher[T, R]) Flush() { + if !b.closed.Load() { + select { + case b.flushCh <- struct{}{}: + default: + } + } +} + // Close closes this chanBatcher to prevents Batch-ing new BatchableTask's and tell the worker goroutine to finish up. func (b *ChanBatcher[_, _]) Close() { if !b.closed.Swap(true) { @@ -185,6 +199,13 @@ func (b *ChanBatcher[T, R]) worker() { klog.Debugf(tasks[0].Ctx(), "ChanBatcher.worker|timer|%d tasks", len(tasks)) go b.batchFnWithRecover(tasks) tasks = tasks[:0:0] + case <-b.flushCh: + if len(tasks) == 0 { + break + } + klog.Debugf(tasks[0].Ctx(), "ChanBatcher.worker|flush|%d tasks", len(tasks)) + go b.batchFnWithRecover(tasks) + tasks = tasks[:0:0] case task, ok := <-b.taskCh: if !ok { ctx := context.Background() diff --git a/batcher_test.go b/batcher_test.go index 8368aab..52e3b5b 100644 --- a/batcher_test.go +++ b/batcher_test.go @@ -21,7 +21,7 @@ func TestChanBatcher(t *testing.T) { return batchRate, 2 }, func(tasks []*ChanTask[time.Duration]) { batchFn(tasks) }) var cnt atomic.Uint32 - start := time.Now() + var start time.Time batchFn = func(tasks []*ChanTask[time.Duration]) { cnt.Add(1) for _, task := range tasks { @@ -31,29 +31,54 @@ func TestChanBatcher(t *testing.T) { task0 := NewChanTask[time.Duration](ctx) task1 := NewChanTask[time.Duration](ctx) task2 := NewChanTask[time.Duration](ctx) + task3 := NewChanTask[time.Duration](ctx) t.Run("happy", func(t *testing.T) { - batcher.Batch(task0) - batcher.Batch(task1) - _, _ = task0.Result() - assert.EqualValues(t, 1, cnt.Load()) - assert.NoError(t, task0.Err) - assert.Less(t, task0.Ret, batchRate) - ret, err := task1.Result() - assert.NoError(t, err) - assert.Less(t, ret, batchRate) - time.Sleep(batchRate * 11 / 10) - runtime.Gosched() + t.Run("trigger max", func(t *testing.T) { + start = time.Now() + batcher.Batch(task0) + batcher.Batch(task1) + _, _ = task0.Result() + assert.EqualValues(t, 1, cnt.Load()) + assert.NoError(t, task0.Err) + assert.Less(t, task0.Ret, batchRate) + ret, err := task1.Result() + assert.NoError(t, err) + assert.Less(t, ret, batchRate) + time.Sleep(batchRate * 11 / 10) + runtime.Gosched() + }) - batcher.Batch(task2) - assert.False(t, task2.IsDone()) - ret, err = task2.Result() - assert.True(t, task2.IsDone()) - assert.EqualValues(t, 2, cnt.Load()) - assert.Equal(t, task2.Err, err) - assert.NoError(t, task2.Err) - assert.Equal(t, task2.Ret, ret) - assert.Greater(t, ret, batchRate) + t.Run("trigger timer after blocked by .Result()", func(t *testing.T) { + start = time.Now() + batcher.Batch(task2) + assert.False(t, task2.IsDone()) + ret, err := task2.Result() + assert.True(t, task2.IsDone()) + assert.EqualValues(t, 2, cnt.Load()) + assert.Equal(t, task2.Err, err) + assert.NoError(t, task2.Err) + assert.Equal(t, task2.Ret, ret) + assert.Greater(t, ret, batchRate) + }) + + t.Run("trigger flush", func(t *testing.T) { + start = time.Now() + batcher.Batch(task3) + assert.False(t, task3.IsDone()) + batcher.Flush() + batcher.Flush() + ret, err := task3.Result() + assert.True(t, task3.IsDone()) + assert.EqualValues(t, 3, cnt.Load()) + assert.Equal(t, task3.Err, err) + assert.NoError(t, task3.Err) + assert.Equal(t, task3.Ret, ret) + assert.Less(t, ret, batchRate) + batcher.Flush() + batcher.Flush() + assert.EqualValues(t, 3, cnt.Load()) + }) }) t.Run("spam", func(t *testing.T) { @@ -111,13 +136,14 @@ func TestChanBatcher(t *testing.T) { assert.ErrorIs(t, task0.Err, panicErr) assert.ErrorIs(t, task1.Err, panicErr) + start = time.Now() batchFn = oldBatchFn task2 = NewChanTask[time.Duration](nil) // nolint:staticcheck batcher.Batch(task2) batcher.Batch(task2) ret, err := task2.Result() assert.NoError(t, err) - assert.Greater(t, ret, batchRate) + assert.Less(t, ret, batchRate) }) t.Run("cancelled task", func(t *testing.T) {