diff --git a/README.md b/README.md index 395854b..cf32590 100644 --- a/README.md +++ b/README.md @@ -126,11 +126,11 @@ if err := wg.Wait(); err != nil { ```go wg := errors.NewWaitGroup() -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) @@ -144,16 +144,16 @@ if err := wg.Wait(); err != nil { ```go wg := errors.NewWaitGroup(errors.WaitGroupWithTaskLimit(2)) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) // we set limit concurrent task to 2, so this task will block until one of above are done. -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) @@ -172,11 +172,11 @@ import ( // in this example we are using ants goroutine pool. wg := errors.NewWaitGroup(errors.WaitGroupWithTaskRunner(ants.Submit)) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient() }) @@ -191,41 +191,18 @@ import ( ) // in this example we are using ants goroutine pool. -wg := errors.NewWaitGroup(errors.WaitGroupWithStopOnError()) +wg := errors.NewWaitGroup(errors.WaitGroupWithStopOnError(), errors.WaitGroupWithContext(context.Background())) // you can pass your own context. -ctx := wg.Context() - -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient(ctx) }) -wg.Do(func() error { +wg.Do(func(ctx context.Context) error { return callingHttpClient(ctx) }) // if one of above task failed, context will cancel and other task will stop (the task must ba aware of context cancellation like http pkg do) -if err := wg.Wait(); err != nil { - // oh, something bad happened in one of routines above. -} -``` -**or you can use NewWaitGroupWithContext method:** -```go -import ( - "github.com/mrsoftware/errors" -) - -// in this example we are using ants goroutine pool. -ctx, wg := errors.NewWaitGroupWithContext(context.Background(), errors.WaitGroupWithStopOnError()) - -wg.Do(func() error { - return callingHttpClient(ctx) -}) - -wg.Do(func() error { - return callingHttpClient(ctx) -}) - if err := wg.Wait(); err != nil { // oh, something bad happened in one of routines above. } diff --git a/waitGroup.go b/waitGroup.go index 17a0fac..cf02bc9 100644 --- a/waitGroup.go +++ b/waitGroup.go @@ -14,6 +14,7 @@ type WaitGroup struct { ctx context.Context cancel context.CancelCauseFunc cancelOnce sync.Once + wait chan struct{} } // NewWaitGroup create new WaitGroup. @@ -35,7 +36,7 @@ func NewWaitGroup(options ...WaitGroupOption) *WaitGroup { ctx, cancel := context.WithCancelCause(ops.Ctx) - return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel} + return &WaitGroup{options: ops, gch: gch, ctx: ctx, cancel: cancel, wait: make(chan struct{})} } // Context of current waitGroup. @@ -54,11 +55,22 @@ func (g *WaitGroup) Wait() (err error) { defer func() { g.Stop(err) }() - if g.errors.Len() == 0 { - return nil + select { + case g.wait <- struct{}{}: + default: } - return &g.errors + return g.errors.Err() +} + +// WaitChan is like Wait method but return chanel. +func (g *WaitGroup) WaitChan() <-chan struct{} { + return g.wait +} + +// Err of tasks. +func (g *WaitGroup) Err() error { + return g.errors.Err() } // Add is sync.WaitGroup.Add. @@ -81,8 +93,14 @@ func (g *WaitGroup) Done(err error) { g.errors.Add(err) } -// Do run the given function in a new goroutine. -func (g *WaitGroup) Do(ctx context.Context, f func(ctx context.Context) error) { +// Do run the given function in a new goroutine with internal context. +func (g *WaitGroup) Do(f func(ctx context.Context) error) { + g.DoWithContext(g.ctx, f) +} + +// DoWithContext run the given function in a new goroutine with custom context. +// if passed context is nil, use the internal context. +func (g *WaitGroup) DoWithContext(ctx context.Context, f func(ctx context.Context) error) { if ctx == nil { ctx = g.ctx } diff --git a/waitGroup_test.go b/waitGroup_test.go index 78cccfe..ea7f425 100644 --- a/waitGroup_test.go +++ b/waitGroup_test.go @@ -85,9 +85,9 @@ func TestGroup(t *testing.T) { wg := NewWaitGroup() - wg.Do(context.Background(), func(ctx context.Context) error { return error1 }) + wg.Do(func(ctx context.Context) error { return error1 }) - wg.Do(context.Background(), func(ctx context.Context) error { return nil }) + wg.Do(func(ctx context.Context) error { return nil }) err := wg.Wait() @@ -100,7 +100,7 @@ func TestGroup(t *testing.T) { limitCount := 1 wg := NewWaitGroup(WaitGroupWithTaskLimit(limitCount)) - wg.Do(context.Background(), func(ctx context.Context) error { + wg.Do(func(ctx context.Context) error { // wait for assertion to do. time.Sleep(100 * time.Millisecond) return nil @@ -124,7 +124,7 @@ func TestGroup(t *testing.T) { wg := NewWaitGroup(WaitGroupWithTaskRunner(runner)) - wg.Do(context.Background(), func(ctx context.Context) error { + wg.Do(func(ctx context.Context) error { return nil }) @@ -138,10 +138,10 @@ func TestGroup(t *testing.T) { wg := NewWaitGroup(WaitGroupWithStopOnError(), WaitGroupWithContext(context.Background())) - wg.Do(nil, func(ctx context.Context) error { return error1 }) + wg.Do(func(ctx context.Context) error { return error1 }) // sample long-running and context aware task. - wg.Do(nil, func(ctx context.Context) error { + wg.Do(func(ctx context.Context) error { for { select { case <-ctx.Done():