Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concurrent safe StubTime test helper #298

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions internal/riverinternaltest/riverinternaltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,29 @@ func TestDB(ctx context.Context, tb testing.TB) *pgxpool.Pool {
return testPool.Pool()
}

// StubTime is a shortcut for stubbing time for a the given archetype at the
// given time. It returns the time given as argument for convenience.
func StubTime(archetype *baseservice.Archetype, t time.Time) time.Time {
// Strip monotonic portion and make UTC so that comparisons are less fraught.
t = t.Round(0).UTC()

archetype.TimeNowUTC = func() time.Time { return t }
return t
// StubTime returns a pair of function for (getTime, setTime), the former of
// which is compatible with `TimeNowUTC` found in the service archetype.
// It's concurrent safe is so that a started service can access its stub
// time while the test case is setting it, and without the race detector
// triggering.
func StubTime(initialTime time.Time) (func() time.Time, func(t time.Time)) {
var (
mu sync.RWMutex
stubbedTime = &initialTime
)

getTime := func() time.Time {
mu.RLock()
defer mu.RUnlock()
return *stubbedTime
}
setTime := func(newTime time.Time) {
mu.Lock()
defer mu.Unlock()
stubbedTime = &newTime
}

return getTime, setTime
}

// A pool and mutex to protect it, lazily initialized by TestTx. Once open, this
Expand Down
33 changes: 33 additions & 0 deletions internal/riverinternaltest/riverinternaltest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,45 @@ import (
"context"
"sync"
"testing"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)

func TestStubTime(t *testing.T) {
t.Parallel()

t.Run("BasicUsage", func(t *testing.T) {
t.Parallel()

initialTime := time.Now()

getTime, setTime := StubTime(initialTime)
require.Equal(t, initialTime, getTime())

newTime := initialTime.Add(1 * time.Second)
setTime(newTime)
require.Equal(t, newTime, getTime())
})

t.Run("Stress", func(t *testing.T) {
t.Parallel()

getTime, setTime := StubTime(time.Now())

for i := 0; i < 10; i++ {
go func() {
for j := 0; j < 50; j++ {
setTime(time.Now())
_ = getTime()
}
}()
}
})
}

// Implemented by `pgx.Tx` or `pgxpool.Pool`. Normally we'd use a similar type
// from `dbsqlc` or `dbutil`, but riverinternaltest is extremely low level and
// that would introduce a cyclic dependency. We could package as
Expand Down
4 changes: 2 additions & 2 deletions job_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func TestJobExecutor_Execute(t *testing.T) {

executor, bundle := setup(t)

baselineTime := riverinternaltest.StubTime(&executor.Archetype, time.Now())
executor.Archetype.TimeNowUTC, _ = riverinternaltest.StubTime(time.Now().UTC())

workerErr := errors.New("job error")
executor.WorkUnit = newWorkUnitFactoryWithCustomRetry(func() error { return workerErr }, nil).MakeUnit(bundle.jobRow)
Expand All @@ -239,7 +239,7 @@ func TestJobExecutor_Execute(t *testing.T) {
require.WithinDuration(t, executor.ClientRetryPolicy.NextRetry(bundle.jobRow), job.ScheduledAt, 1*time.Second)
require.Equal(t, rivertype.JobStateRetryable, job.State)
require.Len(t, job.Errors, 1)
require.Equal(t, baselineTime, job.Errors[0].At)
require.Equal(t, executor.Archetype.TimeNowUTC().Truncate(1*time.Microsecond), job.Errors[0].At.Truncate(1*time.Microsecond))
require.Equal(t, 1, job.Errors[0].Attempt)
require.Equal(t, "job error", job.Errors[0].Error)
require.Equal(t, "", job.Errors[0].Trace)
Expand Down
Loading