Skip to content

Commit

Permalink
handle done chanel like context done chanel
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsoftware committed Oct 8, 2024
1 parent e198855 commit c919624
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 45 deletions.
43 changes: 10 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ if err := wg.Wait(); err != nil {
```go
wg := errors.NewWaitGroup()

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

Expand All @@ -144,16 +144,16 @@ if err := wg.Wait(); err != nil {
```go
wg := errors.NewWaitGroup(errors.WaitGroupWithTaskLimit(2))

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

// we set limit concurrent task to 2, so this task will block until one of above are done.
wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

Expand All @@ -172,11 +172,11 @@ import (
// in this example we are using ants goroutine pool.
wg := errors.NewWaitGroup(errors.WaitGroupWithTaskRunner(ants.Submit))

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient()
})

Expand All @@ -191,41 +191,18 @@ import (
)

// in this example we are using ants goroutine pool.
wg := errors.NewWaitGroup(errors.WaitGroupWithStopOnError())
wg := errors.NewWaitGroup(errors.WaitGroupWithStopOnError(), errors.WaitGroupWithContext(context.Background())) // you can pass your own context.

ctx := wg.Context()

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient(ctx)
})

wg.Do(func() error {
wg.Do(func(ctx context.Context) error {
return callingHttpClient(ctx)
})

// if one of above task failed, context will cancel and other task will stop (the task must ba aware of context cancellation like http pkg do)

if err := wg.Wait(); err != nil {
// oh, something bad happened in one of routines above.
}
```
**or you can use NewWaitGroupWithContext method:**
```go
import (
"github.com/mrsoftware/errors"
)

// in this example we are using ants goroutine pool.
ctx, wg := errors.NewWaitGroupWithContext(context.Background(), errors.WaitGroupWithStopOnError())

wg.Do(func() error {
return callingHttpClient(ctx)
})

wg.Do(func() error {
return callingHttpClient(ctx)
})

if err := wg.Wait(); err != nil {
// oh, something bad happened in one of routines above.
}
Expand Down
30 changes: 24 additions & 6 deletions waitGroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type WaitGroup struct {
ctx context.Context
cancel context.CancelCauseFunc
cancelOnce sync.Once
wait chan struct{}
}

// NewWaitGroup create new WaitGroup.
Expand All @@ -35,7 +36,7 @@ func NewWaitGroup(options ...WaitGroupOption) *WaitGroup {

ctx, cancel := context.WithCancelCause(ops.Ctx)

return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel}
return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel, wait: make(chan struct{})}
}

// Context of current waitGroup.
Expand All @@ -54,11 +55,22 @@ func (g *WaitGroup) Wait() (err error) {

defer func() { g.Stop(err) }()

if g.errors.Len() == 0 {
return nil
select {
case g.wait <- struct{}{}:
default:
}

return &g.errors
return g.errors.Err()
}

// WaitChan is like Wait method but return chanel.
func (g *WaitGroup) WaitChan() <-chan struct{} {
return g.wait
}

// Err of tasks.
func (g *WaitGroup) Err() error {
return g.errors.Err()
}

// Add is sync.WaitGroup.Add.
Expand All @@ -81,8 +93,14 @@ func (g *WaitGroup) Done(err error) {
g.errors.Add(err)
}

// Do run the given function in a new goroutine.
func (g *WaitGroup) Do(ctx context.Context, f func(ctx context.Context) error) {
// Do run the given function in a new goroutine with internal context.
func (g *WaitGroup) Do(f func(ctx context.Context) error) {
g.DoWithContext(g.ctx, f)
}

// DoWithContext run the given function in a new goroutine with custom context.
// if passed context is nil, use the internal context.
func (g *WaitGroup) DoWithContext(ctx context.Context, f func(ctx context.Context) error) {
if ctx == nil {
ctx = g.ctx
}
Expand Down
12 changes: 6 additions & 6 deletions waitGroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func TestGroup(t *testing.T) {

wg := NewWaitGroup()

wg.Do(context.Background(), func(ctx context.Context) error { return error1 })
wg.Do(func(ctx context.Context) error { return error1 })

wg.Do(context.Background(), func(ctx context.Context) error { return nil })
wg.Do(func(ctx context.Context) error { return nil })

err := wg.Wait()

Expand All @@ -100,7 +100,7 @@ func TestGroup(t *testing.T) {
limitCount := 1
wg := NewWaitGroup(WaitGroupWithTaskLimit(limitCount))

wg.Do(context.Background(), func(ctx context.Context) error {
wg.Do(func(ctx context.Context) error {
// wait for assertion to do.
time.Sleep(100 * time.Millisecond)
return nil
Expand All @@ -124,7 +124,7 @@ func TestGroup(t *testing.T) {

wg := NewWaitGroup(WaitGroupWithTaskRunner(runner))

wg.Do(context.Background(), func(ctx context.Context) error {
wg.Do(func(ctx context.Context) error {
return nil
})

Expand All @@ -138,10 +138,10 @@ func TestGroup(t *testing.T) {

wg := NewWaitGroup(WaitGroupWithStopOnError(), WaitGroupWithContext(context.Background()))

wg.Do(nil, func(ctx context.Context) error { return error1 })
wg.Do(func(ctx context.Context) error { return error1 })

// sample long-running and context aware task.
wg.Do(nil, func(ctx context.Context) error {
wg.Do(func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
Expand Down

0 comments on commit c919624

Please sign in to comment.