From 8e479fe592ec57de02020c881446aa3752899273 Mon Sep 17 00:00:00 2001 From: Domino Valdano <2644901+reductionista@users.noreply.github.com> Date: Fri, 1 Nov 2024 11:32:18 -0700 Subject: [PATCH] Create ctx from testing context, instead of using context.Background --- pkg/pg/pg.go | 2 +- pkg/pg/txdb.go | 68 ++++++++++++++++++++++++++++++++------------- pkg/pg/txdb_test.go | 8 ++++++ 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/pkg/pg/pg.go b/pkg/pg/pg.go index a82f052a1..af540f6f2 100644 --- a/pkg/pg/pg.go +++ b/pkg/pg/pg.go @@ -14,7 +14,7 @@ import ( func NewSqlxDB(t testing.TB, dbURL string) *sqlx.DB { tests.SkipShortDB(t) - err := RegisterTxDb(dbURL) + err := RegisterTxDb(tests.Context(t), dbURL) if err != nil { t.Errorf("failed to register txdb dialect: %s", err.Error()) return nil diff --git a/pkg/pg/txdb.go b/pkg/pg/txdb.go index 243e9025d..f03fd5db9 100644 --- a/pkg/pg/txdb.go +++ b/pkg/pg/txdb.go @@ -14,6 +14,8 @@ import ( "github.com/jmoiron/sqlx" "go.uber.org/multierr" + + "github.com/smartcontractkit/chainlink-common/pkg/utils" ) // txdb is a simplified version of https://github.com/DATA-DOG/go-txdb @@ -32,7 +34,7 @@ import ( // store to use the raw DialectPostgres dialect and setup a one-use database. // See heavyweight.FullTestDB() as a convenience function to help you do this, // but please use sparingly because as it's name implies, it is expensive. -func RegisterTxDb(dbURL string) error { +func RegisterTxDb(ctx context.Context, dbURL string) error { drivers := sql.Drivers() for _, driver := range drivers { if driver == string(TransactionWrappedPostgres) { @@ -58,8 +60,15 @@ func RegisterTxDb(dbURL string) error { if !strings.HasSuffix(parsed.Path, "_test") { return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:]) } + abort := make(chan struct{}) + go func() { + <-ctx.Done() + abort <- struct{}{} // abort all queries when context is cancelled + }() + name := string(TransactionWrappedPostgres) sql.Register(name, &txDriver{ + abort: abort, dbURL: dbURL, conns: make(map[string]*conn), }) @@ -76,6 +85,7 @@ var _ driver.SessionResetter = &conn{} // When `Close` is called, transaction is rolled back. type txDriver struct { sync.Mutex + abort <-chan struct{} db *sql.DB conns map[string]*conn @@ -99,7 +109,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c = &conn{tx: tx, opened: 1, dsn: dsn} + c = &conn{abort: d.abort, tx: tx, opened: 1, dsn: dsn} c.removeSelf = func() error { return d.deleteConn(c) } @@ -130,6 +140,7 @@ func (d *txDriver) deleteConn(c *conn) error { type conn struct { sync.Mutex + abort <-chan struct{} dsn string tx *sql.Tx // tx may be shared by many conns, definitive one lives in the map keyed by DSN on the txDriver. Do not modify from conn closed bool @@ -156,26 +167,32 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err // Prepare returns a prepared statement, bound to this connection. func (c *conn) Prepare(query string) (driver.Stmt, error) { - return c.PrepareContext(context.Background(), query) + ctx, cancel := utils.ContextFromChan(c.abort) + defer cancel() + return c.PrepareContext(ctx, query) } // Implement the "ConnPrepareContext" interface -func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { c.Lock() defer c.Unlock() if c.closed { panic("conn is closed") } - // TODO: Fix context handling - // FIXME: It is not safe to give the passed in context to the tx directly + // It is not safe to give the passed in context to the tx directly // because the tx is shared by many conns and cancelling the context will - // destroy the tx which can affect other conns - st, err := c.tx.PrepareContext(context.Background(), query) + // destroy the tx which can affect other conns. Instead, we pass the context + // passed to NewSqlxDb when the database was set up so the operation can at + // least be aborted immediately if the whole test is interrupted. + ctx, cancel := utils.ContextFromChan(c.abort) + defer cancel() + + st, err := c.tx.PrepareContext(ctx, query) if err != nil { return nil, err } - return &stmt{st, c}, nil + return &stmt{c.abort, st, c}, nil } // IsValid is called prior to placing the connection into the @@ -212,8 +229,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam panic("conn is closed") } - // TODO: Fix context handling - rs, err := c.tx.QueryContext(context.Background(), query, mapNamedArgs(args)...) + ctx, cancel := utils.ContextFromChan(c.abort) + defer cancel() + + rs, err := c.tx.QueryContext(ctx, query, mapNamedArgs(args)...) if err != nil { return nil, err } @@ -229,8 +248,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name if c.closed { return nil, fmt.Errorf("conn is closed") } - // TODO: Fix context handling - return c.tx.ExecContext(context.Background(), query, mapNamedArgs(args)...) + ctx, cancel := utils.ContextFromChan(c.abort) + defer cancel() + + return c.tx.ExecContext(ctx, query, mapNamedArgs(args)...) } // tryOpen attempts to increment the open count, but returns false if closed. @@ -305,8 +326,9 @@ func (tx tx) Rollback() error { } type stmt struct { - st *sql.Stmt - conn *conn + abort <-chan struct{} + st *sql.Stmt + conn *conn } func (s stmt) Exec(args []driver.Value) (driver.Result, error) { @@ -325,8 +347,11 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive if s.conn.closed { panic("conn is closed") } - // TODO: Fix context handling - return s.st.ExecContext(context.Background(), mapNamedArgs(args)...) + + ctx, cancel := utils.ContextFromChan(s.abort) + defer cancel() + + return s.st.ExecContext(ctx, mapNamedArgs(args)...) } func mapArgs(args []driver.Value) (res []interface{}) { @@ -358,14 +383,17 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) { } // Implement the "StmtQueryContext" interface -func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { +func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver.Rows, error) { s.conn.Lock() defer s.conn.Unlock() if s.conn.closed { panic("conn is closed") } - // TODO: Fix context handling - rows, err := s.st.QueryContext(context.Background(), mapNamedArgs(args)...) + + ctx, cancel := utils.ContextFromChan(s.abort) + defer cancel() + + rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...) if err != nil { return nil, err } diff --git a/pkg/pg/txdb_test.go b/pkg/pg/txdb_test.go index de0157999..431a6746f 100644 --- a/pkg/pg/txdb_test.go +++ b/pkg/pg/txdb_test.go @@ -1,6 +1,7 @@ package pg import ( + "database/sql" "os" "testing" "time" @@ -52,4 +53,11 @@ func TestTxDBDriver(t *testing.T) { time.Sleep(time.Second * 10) ensureValuesPresent(t, db) }) + + t.Run("Make sure calling sql.Register() can be called twice", func(t *testing.T) { + require.NoError(t, RegisterTxDb(tests.Context(t), "foo")) + require.NoError(t, RegisterTxDb(tests.Context(t), "bar")) + drivers := sql.Drivers() + assert.Contains(t, drivers, "txdb") + }) }