Skip to content

Commit

Permalink
Replace atomic.Uint32 with atomic.Bool and use CompareAndSwap there i…
Browse files Browse the repository at this point in the history
…t's possible.

Replace random delay with constan to make test not blink.
Simplify assertion in test to make it stable.
  • Loading branch information
alexeykiselev committed Dec 10, 2024
1 parent 70a8c34 commit 00a9ebe
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
13 changes: 5 additions & 8 deletions pkg/execution/taskgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import (
type TaskGroup struct {
wg sync.WaitGroup // Counter for active goroutines.

// active is nonzero when the group is "active", meaning there has been at least one call to Run since the group
// active is true when the group is "active", meaning there has been at least one call to Run since the group
// was created or the last Wait.
//
// Together active and errLock work as a kind of resettable sync.Once. The fast path reads active and only
// acquires errLock if it discovers setup is needed.
active atomic.Uint32
active atomic.Bool

errLock sync.Mutex // Guards the fields below.
err error // First captured error returned from Wait.
Expand Down Expand Up @@ -56,7 +56,7 @@ func (g *TaskGroup) OnError(handler func(error) error) *TaskGroup {
// so the [execute] function should include the interruption logic.
func (g *TaskGroup) Run(execute func() error) {
g.wg.Add(1)
if g.active.Load() == 0 {
if !g.active.Load() {
g.activate()
}
go func() {
Expand All @@ -82,9 +82,7 @@ func (g *TaskGroup) Wait() error {
defer g.errLock.Unlock()

// If the group is still active, deactivate it now.
if g.active.Load() != 0 {
g.active.Store(0)
}
g.active.CompareAndSwap(true, false)
return g.err
}

Expand All @@ -93,9 +91,8 @@ func (g *TaskGroup) Wait() error {
func (g *TaskGroup) activate() {
g.errLock.Lock()
defer g.errLock.Unlock()
if g.active.Load() == 0 {
if g.active.CompareAndSwap(false, true) {
g.err = nil
g.active.Store(1)
}
}

Expand Down
8 changes: 1 addition & 7 deletions pkg/execution/taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"math/rand/v2"
"runtime"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -83,8 +82,6 @@ func TestCancelPropagation(t *testing.T) {
}
})
}
runtime.Gosched()
<-time.After(500 * time.Microsecond)
cancel()

err := g.Wait()
Expand All @@ -102,9 +99,6 @@ func TestCancelPropagation(t *testing.T) {
}
}

assert.NotZero(t, numOK)
assert.NotZero(t, numCanceled)
assert.NotZero(t, numOther)
total := int(numOK) + numCanceled + numOther
assert.Equal(t, numTasks, total)
}
Expand All @@ -119,7 +113,7 @@ func TestWaitingForFinish(t *testing.T) {
select {
case <-ctx.Done():
return work(50, nil)()
case <-time.After(randomDuration(60)):
case <-time.After(60 * time.Millisecond):
return failure
}
}
Expand Down

0 comments on commit 00a9ebe

Please sign in to comment.