diff --git a/internal/riverinternaltest/sharedtx/shared_tx.go b/internal/riverinternaltest/sharedtx/shared_tx.go index 7c40a83b..a9f3a086 100644 --- a/internal/riverinternaltest/sharedtx/shared_tx.go +++ b/internal/riverinternaltest/sharedtx/shared_tx.go @@ -73,10 +73,11 @@ func (e *SharedTx) Exec(ctx context.Context, query string, args ...any) (pgconn. func (e *SharedTx) Query(ctx context.Context, query string, args ...any) (pgx.Rows, error) { e.lock() - // no unlock until rows close + // no unlock until rows close or return on error condition rows, err := e.inner.Query(ctx, query, args...) if err != nil { + e.unlock() return nil, err } diff --git a/internal/riverinternaltest/sharedtx/shared_tx_test.go b/internal/riverinternaltest/sharedtx/shared_tx_test.go index 180beae3..e63800da 100644 --- a/internal/riverinternaltest/sharedtx/shared_tx_test.go +++ b/internal/riverinternaltest/sharedtx/shared_tx_test.go @@ -137,4 +137,27 @@ func TestSharedTx(t *testing.T) { require.Len(t, sharedTx.wait, 1) }) + + // Checks specifically that the shared transaction is unlocked correctly on + // the Query function's error path (normally it's unlocked when the returned + // rows struct is closed, so an additional unlock operation is required). + t.Run("QueryUnlocksOnError", func(t *testing.T) { + t.Parallel() + + sharedTx := setup(t) + + { + // Roll back the transaction so using it returns an error. + require.NoError(t, sharedTx.inner.Rollback(ctx)) + + _, err := sharedTx.Query(ctx, "SELECT 1") //nolint:sqlclosecheck + require.ErrorIs(t, err, pgx.ErrTxClosed) + + select { + case <-sharedTx.wait: + default: + require.FailNow(t, "Should have been a value in shared transaction's wait channel") + } + } + }) }