feat: add batcher to batch tasks #3

Merged · 2 commits · Dec 29, 2023
218 changes: 218 additions & 0 deletions batcher.go
@@ -0,0 +1,218 @@
package kutils

import (
"context"
"math"
"runtime"
"runtime/debug"
"sync/atomic"
"time"

"github.com/KyberNetwork/logger"
"github.com/pkg/errors"
)

//go:generate mockgen -source=batcher.go -destination mocks/mocks.go -package mocks

var (
ErrBatcherClosed = errors.New("batcher closed")
)

// BatchableTask represents a batchable task
type BatchableTask[R any] interface {
Ctx() context.Context // The context of this task
Done() <-chan struct{} // Signals if this task was already resolved
IsDone() bool // Checks (non-blocking) if this task was already resolved
Result() (R, error) // Blocks until this task is resolved and returns result and error
Resolve(ret R, err error) // Resolves this task with return value and error
}

// ChanTask uses a done channel to signal resolution of return value and error
type ChanTask[R any] struct {
ctx context.Context
done chan struct{}
Ret R
Err error
}

func NewChanTask[R any](ctx context.Context) *ChanTask[R] {
if ctx == nil {
ctx = context.Background()
}
return &ChanTask[R]{
ctx: ctx,
done: make(chan struct{}),
}
}

func (c *ChanTask[R]) Ctx() context.Context {
return c.ctx
}

func (c *ChanTask[R]) Done() <-chan struct{} {
return c.done
}

func (c *ChanTask[R]) IsDone() bool {
select {
case <-c.done:
return true
default:
return false
}
}

func (c *ChanTask[R]) Result() (R, error) {
if c.IsDone() {
return c.Ret, c.Err
}
select {
case <-c.done:
return c.Ret, c.Err
case <-c.ctx.Done():
return *new(R), c.ctx.Err()
}
}

func (c *ChanTask[R]) Resolve(ret R, err error) {
select {
case <-c.done:
logger.Errorf("ChanTask.Resolve|called twice, ignored|c.Ret=%v,c.Err=%v|Ret=%v,Err=%v", c.Ret, c.Err, ret, err)
default:
c.Ret, c.Err = ret, err
close(c.done)
}
}

// Batcher batches BatchableTask's together and executes a batch logic on each batch of BatchableTask's.
// It skips BatchableTask's whose Ctx is already cancelled and resolves those tasks with the context's error.
// The batch logic should signal each BatchableTask as done by using its Resolve method.
type Batcher[T BatchableTask[R], R any] interface {
// Batch submits a BatchableTask to the batcher.
Batch(task T)
// Close should stop Batch from being called and clean up any background resources.
Close()
}

// BatchCfg provides the batchRate and batchCnt configs for a ChanBatcher. ChanBatcher triggers batch processing
// either batchRate after the first task of a batch is queued, or as soon as batchCnt BatchableTask's are queued.
type BatchCfg func() (batchRate time.Duration, batchCnt int)

// BatchFn is called for a batch of tasks collected and triggered by a ChanBatcher per its batchCfg.
type BatchFn[T any] func([]T)

// ChanBatcher implements Batcher using golang channel.
type ChanBatcher[T BatchableTask[R], R any] struct {
batchCfg BatchCfg
batchFn BatchFn[T]
taskCh chan T
closed atomic.Bool
}

func NewChanBatcher[T BatchableTask[R], R any](batchCfg BatchCfg, batchFn BatchFn[T]) *ChanBatcher[T, R] {
_, batchCnt := batchCfg()
chanBatcher := &ChanBatcher[T, R]{
batchCfg: batchCfg,
batchFn: batchFn,
taskCh: make(chan T, 16*batchCnt),
}
go chanBatcher.worker()
return chanBatcher
}

// Batch submits a BatchableTask to the task channel if this ChanBatcher hasn't been closed,
// otherwise it resolves the task immediately with ErrBatcherClosed.
func (b *ChanBatcher[T, R]) Batch(task T) {
if !b.closed.Load() {
b.taskCh <- task
} else {
task.Resolve(*new(R), ErrBatcherClosed)
}
}

// Close closes this ChanBatcher to prevent Batch-ing new BatchableTask's and tells the worker goroutine to finish up.
func (b *ChanBatcher[_, _]) Close() {
if !b.closed.Swap(true) {
close(b.taskCh)
}
}

// batchFnWithRecover runs batchFn on a batch of tasks, recovering from panics and resolving any
// still-unresolved tasks with the panic wrapped as an error.
func (b *ChanBatcher[T, R]) batchFnWithRecover(tasks []T) {
defer func() {
p := recover()
if p == nil {
return
}
logger.Errorf("ChanBatcher.goBatchFn|recovered from panic: %v\n%s", p, string(debug.Stack()))
var ret R
for _, task := range tasks {
if task.IsDone() {
continue
}
if err, ok := p.(error); ok {
task.Resolve(ret, errors.Wrap(err, "batchFn panicked"))
} else {
task.Resolve(ret, errors.Errorf("batchFn panicked: %v", p))
}
}
}()
b.batchFn(tasks)
}

// worker batches up BatchableTask's from taskCh per batchCfg (flushing at most batchRate after a batch's first task, with at most batchCnt BatchableTask's per batch)
// and triggers batchFn with each batch.
func (b *ChanBatcher[T, R]) worker() {
defer func() {
if p := recover(); p != nil {
logger.Errorf("ChanBatcher.worker|recovered from panic: %v\n%s", p, string(debug.Stack()))
}
}()
var tasks []T
batchTimer := time.NewTimer(time.Duration(math.MaxInt64))
for {
runtime.Gosched() // in case GOMAXPROCS is 1, we need to cooperatively yield
select {
case <-batchTimer.C:
if len(tasks) == 0 {
break
}
logger.Debugf("ChanBatcher.worker|timer|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0]
case task, ok := <-b.taskCh:
if !ok {
logger.Debugf("ChanBatcher.worker|closed|%d tasks", len(tasks))
if len(tasks) > 0 {
go b.batchFnWithRecover(tasks)
}
return
}
if !task.IsDone() {
select {
case <-task.Ctx().Done():
logger.Infof("ChanBatcher.worker|skip|task=%v", task)
task.Resolve(*new(R), task.Ctx().Err())
continue
default:
}
}
duration, batchCount := b.batchCfg()
if len(tasks) == 0 {
logger.Debugf("ChanBatcher.worker|timer start|duration=%s", duration)
if !batchTimer.Stop() {
select {
case <-batchTimer.C:
default:
}
}
batchTimer.Reset(duration)
}
tasks = append(tasks, task)
if len(tasks) >= batchCount {
logger.Debugf("ChanBatcher.worker|max|%d tasks", len(tasks))
go b.batchFnWithRecover(tasks)
tasks = tasks[:0:0]
}
}
}
}
143 changes: 143 additions & 0 deletions batcher_test.go
@@ -0,0 +1,143 @@
package kutils

import (
"context"
"runtime"
"sync/atomic"
"testing"
"time"

"github.com/KyberNetwork/logger"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestChanBatcher(t *testing.T) {
ctx := context.Background()
batchRate := 10 * time.Millisecond
batchFn := func(_ []*ChanTask[time.Duration]) {}
batcher := NewChanBatcher[*ChanTask[time.Duration], time.Duration](func() (time.Duration, int) {
return batchRate, 2
}, func(tasks []*ChanTask[time.Duration]) { batchFn(tasks) })
var cnt atomic.Uint32
start := time.Now()
batchFn = func(tasks []*ChanTask[time.Duration]) {
cnt.Add(1)
for _, task := range tasks {
task.Resolve(time.Since(start), nil)
}
}
task0 := NewChanTask[time.Duration](ctx)
task1 := NewChanTask[time.Duration](ctx)
task2 := NewChanTask[time.Duration](ctx)

t.Run("happy", func(t *testing.T) {
batcher.Batch(task0)
batcher.Batch(task1)
_, _ = task0.Result()
assert.EqualValues(t, 1, cnt.Load())
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
ret, err := task1.Result()
assert.NoError(t, err)
assert.Less(t, ret, batchRate)
time.Sleep(batchRate * 11 / 10)
runtime.Gosched()

batcher.Batch(task2)
assert.False(t, task2.IsDone())
ret, err = task2.Result()
assert.True(t, task2.IsDone())
assert.EqualValues(t, 2, cnt.Load())
assert.Equal(t, task2.Err, err)
assert.NoError(t, task2.Err)
assert.Equal(t, task2.Ret, ret)
assert.Greater(t, ret, batchRate)
})

t.Run("spam", func(t *testing.T) {
batcher := NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 },
func(tasks []*ChanTask[int]) {
for _, task := range tasks {
task.Resolve(0, nil)
}
})
const taskCnt = 1000
tasks := make([]*ChanTask[int], taskCnt)
start := time.Now()
for i := 0; i < taskCnt; i++ {
tasks[i] = NewChanTask[int](ctx)
batcher.Batch(tasks[i])
}
// 1k: 2.561804ms; 1M: 2.62s - average overhead per task = 2.6µs
logger.Warnf("done %d tasks in %v", taskCnt, time.Since(start))
for i := 0; i < taskCnt; i++ {
ret, err := tasks[i].Result()
assert.NoError(t, err)
assert.EqualValues(t, 0, ret)
}
batcher.Close()
})

t.Run("resolve twice", func(t *testing.T) {
task0.Resolve(batchRate, nil)
assert.NoError(t, task0.Err)
assert.Less(t, task0.Ret, batchRate)
})

t.Run("recover from panic", func(t *testing.T) {
oldBatchFn := batchFn
batchFn = func(tasks []*ChanTask[time.Duration]) {
panic("test panic")
}
task0 = NewChanTask[time.Duration](ctx)
task1 = NewChanTask[time.Duration](ctx)
task0.Resolve(0, nil)
batcher.Batch(task0)
batcher.Batch(task1)
<-task1.Done()
assert.ErrorContains(t, task1.Err, "test panic")

panicErr := errors.New("test panic error")
batchFn = func(tasks []*ChanTask[time.Duration]) {
panic(panicErr)
}
task0 = NewChanTask[time.Duration](ctx)
task1 = NewChanTask[time.Duration](ctx)
batcher.Batch(task0)
batcher.Batch(task1)
<-task1.Done()
assert.ErrorIs(t, task0.Err, panicErr)
assert.ErrorIs(t, task1.Err, panicErr)

batchFn = oldBatchFn
task2 = NewChanTask[time.Duration](nil) // nolint:staticcheck
batcher.Batch(task2)
batcher.Batch(task2)
ret, err := task2.Result()
assert.NoError(t, err)
assert.Greater(t, ret, batchRate)
})

t.Run("cancelled task", func(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
task0 = NewChanTask[time.Duration](ctx)
batcher.Batch(task0)
cancel()
_, err := task0.Result()
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("close", func(t *testing.T) {
batcher.Batch(task2)
batcher.Close()
task3 := NewChanTask[time.Duration](ctx)
batcher.Batch(task3)
assert.ErrorIs(t, task3.Err, ErrBatcherClosed)
})

t.Run("invalid task", func(t *testing.T) {
NewChanBatcher[*ChanTask[int], int](func() (time.Duration, int) { return 0, 0 },
nil).Batch(&ChanTask[int]{})
})
}
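
A minimal end-to-end sketch of how a caller would use the new batcher. The module import path github.com/KyberNetwork/kutils, the main-package wrapper, and the 10ms / 8-task batch config are illustrative assumptions rather than part of this change:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/KyberNetwork/kutils"
)

func main() {
	// Flush a batch 10ms after its first task, or as soon as 8 tasks are queued.
	batcher := kutils.NewChanBatcher[*kutils.ChanTask[int], int](
		func() (time.Duration, int) { return 10 * time.Millisecond, 8 },
		func(tasks []*kutils.ChanTask[int]) {
			for i, task := range tasks {
				task.Resolve(i, nil) // batch logic resolves every task
			}
		})
	defer batcher.Close()

	task := kutils.NewChanTask[int](context.Background())
	batcher.Batch(task)
	ret, err := task.Result() // blocks until the batch is flushed and this task resolved
	fmt.Println(ret, err)
}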