-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The package provides type GoroutineManager which is used to launch goroutines until context expires or the manager is stopped. Stop method blocks until all started goroutines stop. Original code by Andras https://go.dev/play/p/HhRpE-K2lA0 Adjustments and tests by Boris.
- Loading branch information
Showing
2 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package fn | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"sync" | ||
) | ||
|
||
// ErrStopping is returned when trying to add a new goroutine while stopping. | ||
var ErrStopping = errors.New("can not add goroutine, stopping") | ||
|
||
// GoroutineManager is used to launch goroutines until context expires or the | ||
// manager is stopped. The Stop method blocks until all started goroutines stop. | ||
type GoroutineManager struct { | ||
wg sync.WaitGroup | ||
mu sync.Mutex | ||
ctx context.Context | ||
cancel func() | ||
} | ||
|
||
// NewGoroutineManager constructs and returns a new instance of | ||
// GoroutineManager. | ||
func NewGoroutineManager(ctx context.Context) *GoroutineManager { | ||
ctx, cancel := context.WithCancel(ctx) | ||
|
||
return &GoroutineManager{ | ||
ctx: ctx, | ||
cancel: cancel, | ||
} | ||
} | ||
|
||
// Go starts a new goroutine if the manager is not stopping. | ||
func (g *GoroutineManager) Go(f func(ctx context.Context)) error { | ||
// Calling wg.Add(1) and wg.Wait() when wg's counter is 0 is a race | ||
// condition, since it is not clear should Wait() block or not. This | ||
// kind of race condition is detected by Go runtime and results in a | ||
// crash if running with `-race`. To prevent this, whole Go method is | ||
// protected with a mutex. The call to wg.Wait() inside Stop() can still | ||
// run in parallel with Go, but in that case g.ctx is in expired state, | ||
// because cancel() was called in Stop, so Go returns before wg.Add(1) | ||
// call. | ||
g.mu.Lock() | ||
defer g.mu.Unlock() | ||
|
||
if g.ctx.Err() != nil { | ||
return ErrStopping | ||
} | ||
|
||
g.wg.Add(1) | ||
go func() { | ||
defer g.wg.Done() | ||
f(g.ctx) | ||
}() | ||
|
||
return nil | ||
} | ||
|
||
// Stop prevents new goroutines from being added and waits for all running | ||
// goroutines to finish. | ||
func (g *GoroutineManager) Stop() { | ||
g.mu.Lock() | ||
g.cancel() | ||
g.mu.Unlock() | ||
|
||
// Wait for all goroutines to finish. Note that this wg.Wait() call is | ||
// safe, since it can't run in parallel with wg.Add(1) call in Go, since | ||
// we just cancelled the context and even if Go call starts running here | ||
// after acquiring the mutex, it would see that the context has expired | ||
// and return ErrStopping instead of calling wg.Add(1). | ||
g.wg.Wait() | ||
} | ||
|
||
// Context returns internal context of the GoroutineManager which will expire | ||
// when either the context passed to NewGoroutineManager expires or when Stop | ||
// is called. | ||
func (g *GoroutineManager) Context() context.Context { | ||
return g.ctx | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package fn | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
// TestGoroutineManager tests that the GoroutineManager starts goroutines until | ||
// ctx expires. It also makes sure it fails to start new goroutines after the | ||
// context expired and the GoroutineManager is in the process of waiting for | ||
// already started goroutines in the Stop method. | ||
func TestGoroutineManager(t *testing.T) { | ||
t.Parallel() | ||
|
||
m := NewGoroutineManager(context.Background()) | ||
|
||
taskChan := make(chan struct{}) | ||
|
||
require.NoError(t, m.Go(func(ctx context.Context) { | ||
<-taskChan | ||
})) | ||
|
||
t1 := time.Now() | ||
|
||
// Close taskChan in 1s, causing the goroutine to stop. | ||
time.AfterFunc(time.Second, func() { | ||
close(taskChan) | ||
}) | ||
|
||
m.Stop() | ||
stopDelay := time.Since(t1) | ||
|
||
// Make sure Stop was waiting for the goroutine to stop. | ||
require.Greater(t, stopDelay, time.Second) | ||
|
||
// Make sure new goroutines do not start after Stop. | ||
require.ErrorIs(t, m.Go(func(ctx context.Context) {}), ErrStopping) | ||
|
||
// When Stop() is called, m.Context() expires. Test that it expired. | ||
select { | ||
case <-m.Context().Done(): | ||
default: | ||
t.Errorf("context must expire at this point") | ||
} | ||
} | ||
|
||
// TestGoroutineManagerContextExpires tests the effect of context expiry. | ||
func TestGoroutineManagerContextExpires(t *testing.T) { | ||
t.Parallel() | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
|
||
m := NewGoroutineManager(ctx) | ||
|
||
require.NoError(t, m.Go(func(ctx context.Context) { | ||
<-ctx.Done() | ||
})) | ||
|
||
// The context of the manager should not expire, so the following call | ||
// must block. | ||
select { | ||
case <-m.Context().Done(): | ||
t.Errorf("context must not expire at this point") | ||
default: | ||
} | ||
|
||
cancel() | ||
|
||
// The context of the manager should expire, so the following call | ||
// must not block. | ||
select { | ||
case <-m.Context().Done(): | ||
default: | ||
t.Errorf("context must expire at this point") | ||
} | ||
|
||
// Make sure new goroutines do not start after context expiry. | ||
require.ErrorIs(t, m.Go(func(ctx context.Context) {}), ErrStopping) | ||
|
||
// Stop will wait for all goroutines to stop. | ||
m.Stop() | ||
} | ||
|
||
// TestGoroutineManagerStress starts many goroutines while calling Stop. It | ||
// is needed to make sure the GoroutineManager does not crash if this happen. | ||
// If the mutex was not used, it would crash because of a race condition between | ||
// wg.Add(1) and wg.Wait(). | ||
func TestGoroutineManagerStress(t *testing.T) { | ||
t.Parallel() | ||
|
||
m := NewGoroutineManager(context.Background()) | ||
|
||
stopChan := make(chan struct{}) | ||
|
||
time.AfterFunc(1*time.Millisecond, func() { | ||
m.Stop() | ||
close(stopChan) | ||
}) | ||
|
||
// Starts 100 goroutines sequentially. Sequential order is needed to | ||
// keep wg.counter low (0 or 1) to increase probability of race | ||
// condition to be caught if it exists. If mutex is removed in the | ||
// implementation, this test crashes under `-race`. | ||
for i := 0; i < 100; i++ { | ||
taskChan := make(chan struct{}) | ||
err := m.Go(func(ctx context.Context) { | ||
close(taskChan) | ||
}) | ||
// If goroutine was started, wait for its completion. | ||
if err == nil { | ||
<-taskChan | ||
} | ||
} | ||
|
||
// Wait for Stop to complete. | ||
<-stopChan | ||
} |