From e198855faa58c444e7594861300ec5a3c3a877d8 Mon Sep 17 00:00:00 2001 From: Mohammad Rajabloo Date: Tue, 8 Oct 2024 19:51:38 +0330 Subject: [PATCH] refactor waitGroup with new approach for context --- waitGroup.go | 30 ++++++++++++++++++------------ waitGroup_test.go | 16 ++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/waitGroup.go b/waitGroup.go index 424e8a7..17a0fac 100644 --- a/waitGroup.go +++ b/waitGroup.go @@ -18,16 +18,10 @@ type WaitGroup struct { // NewWaitGroup create new WaitGroup. func NewWaitGroup(options ...WaitGroupOption) *WaitGroup { - _, wg := NewWaitGroupWithContext(context.Background(), options...) - - return wg -} - -// NewWaitGroupWithContext create new WaitGroup with custom context. -func NewWaitGroupWithContext(ctx context.Context, options ...WaitGroupOption) (context.Context, *WaitGroup) { ops := &WaitGroupOptions{ Wg: &sync.WaitGroup{}, TaskRunner: func(task func()) { go task() }, + Ctx: context.Background(), } for _, op := range options { @@ -39,9 +33,9 @@ func NewWaitGroupWithContext(ctx context.Context, options ...WaitGroupOption) (c gch = make(chan struct{}, ops.TaskLimit) } - ctx, cancel := context.WithCancelCause(context.Background()) + ctx, cancel := context.WithCancelCause(ops.Ctx) - return ctx, &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel} + return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel} } // Context of current waitGroup. @@ -87,8 +81,12 @@ func (g *WaitGroup) Done(err error) { g.errors.Add(err) } -// Do calls the given function in a new goroutine. -func (g *WaitGroup) Do(f func() error) { +// Do run the given function in a new goroutine. +func (g *WaitGroup) Do(ctx context.Context, f func(ctx context.Context) error) { + if ctx == nil { + ctx = g.ctx + } + if g.gch != nil { g.gch <- struct{}{} } @@ -96,7 +94,7 @@ func (g *WaitGroup) Do(f func() error) { g.Add(1) g.options.TaskRunner(func() { - g.Done(f()) + g.Done(f(ctx)) if g.gch != nil { <-g.gch @@ -117,6 +115,7 @@ type WaitGroupOptions struct { TaskLimit int TaskRunner WaitGroupTaskRunner StopOnError bool + Ctx context.Context } type WaitGroupOption func(group *WaitGroupOptions) @@ -151,3 +150,10 @@ func WaitGroupWithStopOnError() WaitGroupOption { g.StopOnError = true } } + +// WaitGroupWithContext if you want to pass your context. +func WaitGroupWithContext(ctx context.Context) WaitGroupOption { + return func(g *WaitGroupOptions) { + g.Ctx = ctx + } +} diff --git a/waitGroup_test.go b/waitGroup_test.go index 16eadf5..78cccfe 100644 --- a/waitGroup_test.go +++ b/waitGroup_test.go @@ -85,9 +85,9 @@ func TestGroup(t *testing.T) { wg := NewWaitGroup() - wg.Do(func() error { return error1 }) + wg.Do(context.Background(), func(ctx context.Context) error { return error1 }) - wg.Do(func() error { return nil }) + wg.Do(context.Background(), func(ctx context.Context) error { return nil }) err := wg.Wait() @@ -100,7 +100,7 @@ func TestGroup(t *testing.T) { limitCount := 1 wg := NewWaitGroup(WaitGroupWithTaskLimit(limitCount)) - wg.Do(func() error { + wg.Do(context.Background(), func(ctx context.Context) error { // wait for assertion to do. time.Sleep(100 * time.Millisecond) return nil @@ -124,7 +124,7 @@ func TestGroup(t *testing.T) { wg := NewWaitGroup(WaitGroupWithTaskRunner(runner)) - wg.Do(func() error { + wg.Do(context.Background(), func(ctx context.Context) error { return nil }) @@ -136,12 +136,12 @@ func TestGroup(t *testing.T) { t.Run("set StopOnError options", func(t *testing.T) { error1 := errors.New("error 1") - ctx, wg := NewWaitGroupWithContext(context.Background(), WaitGroupWithStopOnError()) + wg := NewWaitGroup(WaitGroupWithStopOnError(), WaitGroupWithContext(context.Background())) - wg.Do(func() error { return error1 }) + wg.Do(nil, func(ctx context.Context) error { return error1 }) // sample long-running and context aware task. - wg.Do(func() error { + wg.Do(nil, func(ctx context.Context) error { for { select { case <-ctx.Done(): @@ -154,7 +154,7 @@ func TestGroup(t *testing.T) { expected := NewMultiError(error1, context.Canceled) assert.ElementsMatch(t, expected.errors, err.(*MultiError).errors) - assert.Equal(t, ctx, wg.Context()) + assert.NotNil(t, wg.Context()) }) }