Skip to content

Commit

Permalink
Only abort tx's when last connection is closed
Browse files Browse the repository at this point in the history
Also: convert rest of panic isn't ordinary errors
  • Loading branch information
reductionista committed Nov 2, 2024
1 parent 8e479fe commit 9035c2a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pkg/pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func NewSqlxDB(t testing.TB, dbURL string) *sqlx.DB {
tests.SkipShortDB(t)
err := RegisterTxDb(tests.Context(t), dbURL)
err := RegisterTxDb(dbURL)
if err != nil {
t.Errorf("failed to register txdb dialect: %s", err.Error())
return nil
Expand Down
54 changes: 25 additions & 29 deletions pkg/pg/txdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"database/sql/driver"
"flag"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -34,18 +33,14 @@ import (
// store to use the raw DialectPostgres dialect and setup a one-use database.
// See heavyweight.FullTestDB() as a convenience function to help you do this,
// but please use sparingly because as it's name implies, it is expensive.
func RegisterTxDb(ctx context.Context, dbURL string) error {
func RegisterTxDb(dbURL string) error {
drivers := sql.Drivers()
for _, driver := range drivers {
if driver == string(TransactionWrappedPostgres) {
// TxDB driver already registered
return nil
}
}
testing.Init()
if !flag.Parsed() {
flag.Parse()
}
if testing.Short() {
// -short tests don't need a DB
return nil
Expand All @@ -60,15 +55,10 @@ func RegisterTxDb(ctx context.Context, dbURL string) error {
if !strings.HasSuffix(parsed.Path, "_test") {
return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:])
}
abort := make(chan struct{})
go func() {
<-ctx.Done()
abort <- struct{}{} // abort all queries when context is cancelled
}()

name := string(TransactionWrappedPostgres)
sql.Register(name, &txDriver{
abort: abort,
abort: make(chan struct{}),
dbURL: dbURL,
conns: make(map[string]*conn),
})
Expand All @@ -85,7 +75,7 @@ var _ driver.SessionResetter = &conn{}
// When `Close` is called, transaction is rolled back.
type txDriver struct {
sync.Mutex
abort <-chan struct{}
abort chan struct{}
db *sql.DB
conns map[string]*conn

Expand Down Expand Up @@ -130,6 +120,7 @@ func (d *txDriver) deleteConn(c *conn) error {
}
delete(d.conns, c.dsn)
if len(d.conns) == 0 && d.db != nil {
close(d.abort)
if err := d.db.Close(); err != nil {
return err
}
Expand All @@ -152,7 +143,7 @@ func (c *conn) Begin() (driver.Tx, error) {
c.Lock()
defer c.Unlock()
if c.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}
// Begin is a noop because the transaction was already opened
return tx{c.tx}, nil
Expand All @@ -177,7 +168,7 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
c.Lock()
defer c.Unlock()
if c.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}

// It is not safe to give the passed in context to the tx directly
Expand Down Expand Up @@ -226,7 +217,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
c.Lock()
defer c.Unlock()
if c.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}

ctx, cancel := utils.ContextFromChan(c.abort)
Expand Down Expand Up @@ -277,38 +268,43 @@ func (c *conn) tryOpen() bool {
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
func (c *conn) Close() (err error) {
if !c.close() {
return
newlyClosed, err := c.close()
if err != nil {
return err
}
if !newlyClosed {
return nil
}

// Wait to remove self to avoid nesting locks.
if err := c.removeSelf(); err != nil {
panic(err)
if err = c.removeSelf(); err != nil {
return err
}
return
}

func (c *conn) close() bool {
func (c *conn) close() (bool, error) {
c.Lock()
defer c.Unlock()
if c.closed {
// Double close, should be a safe to make this a noop
// PGX allows double close
// See: https://github.com/jackc/pgx/blob/a457da8bffa4f90ad672fa093ee87f20cf06687b/conn.go#L249
return false
return false, nil
}

c.opened--
if c.opened > 0 {
return false
return false, nil
}
if c.tx != nil {
if err := c.tx.Rollback(); err != nil {
panic(err)
return false, err
}
c.tx = nil
}
c.closed = true
return true
return true, nil
}

type tx struct {
Expand All @@ -335,7 +331,7 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
s.conn.Lock()
defer s.conn.Unlock()
if s.conn.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}
return s.st.Exec(mapArgs(args)...)
}
Expand All @@ -345,7 +341,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
s.conn.Lock()
defer s.conn.Unlock()
if s.conn.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}

ctx, cancel := utils.ContextFromChan(s.abort)
Expand All @@ -370,7 +366,7 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
s.conn.Lock()
defer s.conn.Unlock()
if s.conn.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}
rows, err := s.st.Query(mapArgs(args)...)
defer func() {
Expand All @@ -387,7 +383,7 @@ func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver
s.conn.Lock()
defer s.conn.Unlock()
if s.conn.closed {
panic("conn is closed")
return nil, fmt.Errorf("conn is closed")
}

ctx, cancel := utils.ContextFromChan(s.abort)
Expand Down
4 changes: 2 additions & 2 deletions pkg/pg/txdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestTxDBDriver(t *testing.T) {
})

t.Run("Make sure calling sql.Register() can be called twice", func(t *testing.T) {
require.NoError(t, RegisterTxDb(tests.Context(t), "foo"))
require.NoError(t, RegisterTxDb(tests.Context(t), "bar"))
require.NoError(t, RegisterTxDb("foo"))
require.NoError(t, RegisterTxDb("bar"))
drivers := sql.Drivers()
assert.Contains(t, drivers, "txdb")
})
Expand Down

0 comments on commit 9035c2a

Please sign in to comment.