diff --git a/pkg/pg/pg.go b/pkg/pg/pg.go index af540f6f2..a82f052a1 100644 --- a/pkg/pg/pg.go +++ b/pkg/pg/pg.go @@ -14,7 +14,7 @@ import ( func NewSqlxDB(t testing.TB, dbURL string) *sqlx.DB { tests.SkipShortDB(t) - err := RegisterTxDb(tests.Context(t), dbURL) + err := RegisterTxDb(dbURL) if err != nil { t.Errorf("failed to register txdb dialect: %s", err.Error()) return nil diff --git a/pkg/pg/txdb.go b/pkg/pg/txdb.go index f03fd5db9..75566c86a 100644 --- a/pkg/pg/txdb.go +++ b/pkg/pg/txdb.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "database/sql/driver" - "flag" "fmt" "io" "net/url" @@ -34,7 +33,7 @@ import ( // store to use the raw DialectPostgres dialect and setup a one-use database. // See heavyweight.FullTestDB() as a convenience function to help you do this, // but please use sparingly because as it's name implies, it is expensive. -func RegisterTxDb(ctx context.Context, dbURL string) error { +func RegisterTxDb(dbURL string) error { drivers := sql.Drivers() for _, driver := range drivers { if driver == string(TransactionWrappedPostgres) { @@ -42,10 +41,6 @@ func RegisterTxDb(ctx context.Context, dbURL string) error { return nil } } - testing.Init() - if !flag.Parsed() { - flag.Parse() - } if testing.Short() { // -short tests don't need a DB return nil @@ -60,15 +55,10 @@ func RegisterTxDb(ctx context.Context, dbURL string) error { if !strings.HasSuffix(parsed.Path, "_test") { return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:]) } - abort := make(chan struct{}) - go func() { - <-ctx.Done() - abort <- struct{}{} // abort all queries when context is cancelled - }() name := string(TransactionWrappedPostgres) sql.Register(name, &txDriver{ - abort: abort, + abort: make(chan struct{}), dbURL: dbURL, conns: make(map[string]*conn), }) @@ -85,7 +75,7 @@ var _ driver.SessionResetter = &conn{} // When `Close` is called, transaction is rolled back. type txDriver struct { sync.Mutex - abort <-chan struct{} + abort chan struct{} db *sql.DB conns map[string]*conn @@ -130,6 +120,7 @@ func (d *txDriver) deleteConn(c *conn) error { } delete(d.conns, c.dsn) if len(d.conns) == 0 && d.db != nil { + close(d.abort) if err := d.db.Close(); err != nil { return err } @@ -152,7 +143,7 @@ func (c *conn) Begin() (driver.Tx, error) { c.Lock() defer c.Unlock() if c.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } // Begin is a noop because the transaction was already opened return tx{c.tx}, nil @@ -177,7 +168,7 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err c.Lock() defer c.Unlock() if c.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } // It is not safe to give the passed in context to the tx directly @@ -226,7 +217,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam c.Lock() defer c.Unlock() if c.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } ctx, cancel := utils.ContextFromChan(c.abort) @@ -277,38 +268,43 @@ func (c *conn) tryOpen() bool { // Drivers must ensure all network calls made by Close // do not block indefinitely (e.g. apply a timeout). func (c *conn) Close() (err error) { - if !c.close() { - return + newlyClosed, err := c.close() + if err != nil { + return err + } + if !newlyClosed { + return nil } + // Wait to remove self to avoid nesting locks. - if err := c.removeSelf(); err != nil { - panic(err) + if err = c.removeSelf(); err != nil { + return err } return } -func (c *conn) close() bool { +func (c *conn) close() (bool, error) { c.Lock() defer c.Unlock() if c.closed { // Double close, should be a safe to make this a noop // PGX allows double close // See: https://github.com/jackc/pgx/blob/a457da8bffa4f90ad672fa093ee87f20cf06687b/conn.go#L249 - return false + return false, nil } c.opened-- if c.opened > 0 { - return false + return false, nil } if c.tx != nil { if err := c.tx.Rollback(); err != nil { - panic(err) + return false, err } c.tx = nil } c.closed = true - return true + return true, nil } type tx struct { @@ -335,7 +331,7 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) { s.conn.Lock() defer s.conn.Unlock() if s.conn.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } return s.st.Exec(mapArgs(args)...) } @@ -345,7 +341,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive s.conn.Lock() defer s.conn.Unlock() if s.conn.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } ctx, cancel := utils.ContextFromChan(s.abort) @@ -370,7 +366,7 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) { s.conn.Lock() defer s.conn.Unlock() if s.conn.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } rows, err := s.st.Query(mapArgs(args)...) defer func() { @@ -387,7 +383,7 @@ func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver s.conn.Lock() defer s.conn.Unlock() if s.conn.closed { - panic("conn is closed") + return nil, fmt.Errorf("conn is closed") } ctx, cancel := utils.ContextFromChan(s.abort) diff --git a/pkg/pg/txdb_test.go b/pkg/pg/txdb_test.go index 431a6746f..703b53336 100644 --- a/pkg/pg/txdb_test.go +++ b/pkg/pg/txdb_test.go @@ -55,8 +55,8 @@ func TestTxDBDriver(t *testing.T) { }) t.Run("Make sure calling sql.Register() can be called twice", func(t *testing.T) { - require.NoError(t, RegisterTxDb(tests.Context(t), "foo")) - require.NoError(t, RegisterTxDb(tests.Context(t), "bar")) + require.NoError(t, RegisterTxDb("foo")) + require.NoError(t, RegisterTxDb("bar")) drivers := sql.Drivers() assert.Contains(t, drivers, "txdb") })