Skip to content

Commit

Permalink
refactor waitGroup with new approach for context
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsoftware committed Oct 8, 2024
1 parent 064771b commit e198855
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
30 changes: 18 additions & 12 deletions waitGroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,10 @@ type WaitGroup struct {

// NewWaitGroup create new WaitGroup.
func NewWaitGroup(options ...WaitGroupOption) *WaitGroup {
_, wg := NewWaitGroupWithContext(context.Background(), options...)

return wg
}

// NewWaitGroupWithContext create new WaitGroup with custom context.
func NewWaitGroupWithContext(ctx context.Context, options ...WaitGroupOption) (context.Context, *WaitGroup) {
ops := &WaitGroupOptions{
Wg: &sync.WaitGroup{},
TaskRunner: func(task func()) { go task() },
Ctx: context.Background(),
}

for _, op := range options {
Expand All @@ -39,9 +33,9 @@ func NewWaitGroupWithContext(ctx context.Context, options ...WaitGroupOption) (c
gch = make(chan struct{}, ops.TaskLimit)
}

ctx, cancel := context.WithCancelCause(context.Background())
ctx, cancel := context.WithCancelCause(ops.Ctx)

return ctx, &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel}
return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel}
}

// Context of current waitGroup.
Expand Down Expand Up @@ -87,16 +81,20 @@ func (g *WaitGroup) Done(err error) {
g.errors.Add(err)
}

// Do calls the given function in a new goroutine.
func (g *WaitGroup) Do(f func() error) {
// Do run the given function in a new goroutine.
func (g *WaitGroup) Do(ctx context.Context, f func(ctx context.Context) error) {
if ctx == nil {
ctx = g.ctx
}

if g.gch != nil {
g.gch <- struct{}{}
}

g.Add(1)

g.options.TaskRunner(func() {
g.Done(f())
g.Done(f(ctx))

if g.gch != nil {
<-g.gch
Expand All @@ -117,6 +115,7 @@ type WaitGroupOptions struct {
TaskLimit int
TaskRunner WaitGroupTaskRunner
StopOnError bool
Ctx context.Context
}

type WaitGroupOption func(group *WaitGroupOptions)
Expand Down Expand Up @@ -151,3 +150,10 @@ func WaitGroupWithStopOnError() WaitGroupOption {
g.StopOnError = true
}
}

// WaitGroupWithContext if you want to pass your context.
func WaitGroupWithContext(ctx context.Context) WaitGroupOption {
return func(g *WaitGroupOptions) {
g.Ctx = ctx
}
}
16 changes: 8 additions & 8 deletions waitGroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func TestGroup(t *testing.T) {

wg := NewWaitGroup()

wg.Do(func() error { return error1 })
wg.Do(context.Background(), func(ctx context.Context) error { return error1 })

wg.Do(func() error { return nil })
wg.Do(context.Background(), func(ctx context.Context) error { return nil })

err := wg.Wait()

Expand All @@ -100,7 +100,7 @@ func TestGroup(t *testing.T) {
limitCount := 1
wg := NewWaitGroup(WaitGroupWithTaskLimit(limitCount))

wg.Do(func() error {
wg.Do(context.Background(), func(ctx context.Context) error {
// wait for assertion to do.
time.Sleep(100 * time.Millisecond)
return nil
Expand All @@ -124,7 +124,7 @@ func TestGroup(t *testing.T) {

wg := NewWaitGroup(WaitGroupWithTaskRunner(runner))

wg.Do(func() error {
wg.Do(context.Background(), func(ctx context.Context) error {
return nil
})

Expand All @@ -136,12 +136,12 @@ func TestGroup(t *testing.T) {
t.Run("set StopOnError options", func(t *testing.T) {
error1 := errors.New("error 1")

ctx, wg := NewWaitGroupWithContext(context.Background(), WaitGroupWithStopOnError())
wg := NewWaitGroup(WaitGroupWithStopOnError(), WaitGroupWithContext(context.Background()))

wg.Do(func() error { return error1 })
wg.Do(nil, func(ctx context.Context) error { return error1 })

// sample long-running and context aware task.
wg.Do(func() error {
wg.Do(nil, func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
Expand All @@ -154,7 +154,7 @@ func TestGroup(t *testing.T) {

expected := NewMultiError(error1, context.Canceled)
assert.ElementsMatch(t, expected.errors, err.(*MultiError).errors)
assert.Equal(t, ctx, wg.Context())
assert.NotNil(t, wg.Context())
})
}

Expand Down

0 comments on commit e198855

Please sign in to comment.