From ca2d5c6f46e48c4bba2b0702c541ded782a990a9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 11 Nov 2024 12:52:51 +0400 Subject: [PATCH 1/4] sql: add Database.Connection/WithConnection, interface cleanup This adds a possibility to take a connection from the pool to use it via the Executor interface, and return it later when it's no longer needed. This avoids connection pool overhead in cases when a lot of quries need to be made, but the use of read transactions is not needed. Using read transactions instead of simple connections has the side effect of blocking WAL checkpoints. --- sql/database.go | 136 +++++++++++++++--- sql/database_test.go | 37 ++++- sql/interface.go | 5 +- sql/migrations.go | 5 - sql/mocks.go | 38 ----- sql/schema.go | 19 +-- .../migrations/state_0021_migration.go | 4 - .../migrations/state_0025_migration.go | 4 - 8 files changed, 163 insertions(+), 85 deletions(-) diff --git a/sql/database.go b/sql/database.go index 94a9322563..1620f633fc 100644 --- a/sql/database.go +++ b/sql/database.go @@ -572,24 +572,83 @@ type Interceptor func(query string) error type Database interface { Executor QueryCache + // Close closes the database. Close() error + // QueryCount returns the number of queries executed on the database. QueryCount() int + // QueryCache returns the query cache for this database, if it's present, + // or nil otherwise. QueryCache() QueryCache + // Tx creates deferred sqlite transaction. + // + // Deferred transactions are not started until the first statement. + // Transaction may be started in read mode and automatically upgraded to write mode + // after one of the write statements. + // + // https://www.sqlite.org/lang_transaction.html Tx(ctx context.Context) (Transaction, error) + // WithTx starts a new transaction and passes it to the exec function. + // It then commits the transaction if the exec function doesn't return an error, + // and rolls it back otherwise. + // If the context is canceled, the currently running SQL statement is interrupted. WithTx(ctx context.Context, exec func(Transaction) error) error + // TxImmediate begins a new immediate transaction on the database, that is, + // a transaction that starts a write immediately without waiting for a write + // statement. + // The transaction returned from this function must always be released by calling + // its Release method. Release rolls back the transaction if it hasn't been + // committed. + // If the context is canceled, the currently running SQL statement is interrupted. TxImmediate(ctx context.Context) (Transaction, error) + // WithTxImmediate starts a new immediate transaction and passes it to the exec + // function. + // An immediate transaction is started immediately, without waiting for a write + // statement. + // It then commits the transaction if the exec function doesn't return an error, + // and rolls it back otherwise. + // If the context is canceled, the currently running SQL statement is interrupted. WithTxImmediate(ctx context.Context, exec func(Transaction) error) error + // Connection returns a connection from the database pool. + // If many queries are to be executed in a row, but there's no need for an + // explicit transaction which may be long-running and thus block + // WAL checkpointing, it may be preferable to use a single connection for + // it to avoid database pool overhead. + // The connection needs to be always returned to the pool by calling its Release + // method. + // If the context is canceled, the currently running SQL statement is interrupted. + Connection(ctx context.Context) (Connection, error) + // WithConnection executes the provided function with a connection from the + // database pool. + // The connection is released back to the pool after the function returns. + // If the context is canceled, the currently running SQL statement is interrupted. + WithConnection(ctx context.Context, exec func(Connection) error) error + // Intercept adds an interceptor function to the database. The interceptor + // functions are invoked upon each query on the database, including queries + // executed within transactions. + // The query will fail if the interceptor returns an error. + // The interceptor can later be removed using RemoveInterceptor with the same key. Intercept(key string, fn Interceptor) + // RemoveInterceptor removes the interceptor function with specified key from the database. RemoveInterceptor(key string) } // Transaction represents a transaction. type Transaction interface { Executor + // Commit commits the transaction. Commit() error + // Release releases the transaction. If the transaction hasn't been committed, + // it's rolled back. Release() error } +// Connection represents a database connection. +type Connection interface { + Executor + // Release releases the connection back to the connection pool. + Release() +} + type sqliteDatabase struct { *queryCache pool *sqlitex.Pool @@ -684,34 +743,22 @@ func (db *sqliteDatabase) startExclusive() error { return nil } -// Tx creates deferred sqlite transaction. -// -// Deferred transactions are not started until the first statement. -// Transaction may be started in read mode and automatically upgraded to write mode -// after one of the write statements. -// -// https://www.sqlite.org/lang_transaction.html +// Tx implements Database. func (db *sqliteDatabase) Tx(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginDefault) } -// WithTx will pass initialized deferred transaction to exec callback. -// Will commit only if error is nil. +// WithTx implements Database. func (db *sqliteDatabase) WithTx(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginDefault, exec) } -// TxImmediate creates immediate transaction. -// -// IMMEDIATE cause the database connection to start a new write immediately, without waiting -// for a write statement. The BEGIN IMMEDIATE might fail with SQLITE_BUSY if another write -// transaction is already active on another database connection. +// TxImmediate implements Database. func (db *sqliteDatabase) TxImmediate(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginImmediate) } -// WithTxImmediate will pass initialized immediate transaction to exec callback. -// Will commit only if error is nil. +// WithTxImmediate implements Database. func (db *sqliteDatabase) WithTxImmediate(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginImmediate, exec) } @@ -727,7 +774,7 @@ func (db *sqliteDatabase) runInterceptors(query string) error { return nil } -// Exec statement using one of the connection from the pool. +// Exec implements Executor. // // If you care about atomicity of the operation (for example writing rewards to multiple accounts) // Tx should be used. Otherwise sqlite will not guarantee that all side-effects of operations are @@ -758,7 +805,7 @@ func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) ( return exec(conn, query, encoder, decoder) } -// Close closes all pooled connections. +// Close implements Database. func (db *sqliteDatabase) Close() error { db.closeMux.Lock() defer db.closeMux.Unlock() @@ -772,6 +819,30 @@ func (db *sqliteDatabase) Close() error { return nil } +// Connection implements Database. +func (db *sqliteDatabase) Connection(ctx context.Context) (Connection, error) { + if db.closed { + return nil, ErrClosed + } + conCtx, cancel := context.WithCancel(ctx) + conn := db.getConn(conCtx) + if conn == nil { + cancel() + return nil, ErrNoConnection + } + return &sqliteConn{queryCache: db.queryCache, db: db, conn: conn, freeConn: cancel}, nil +} + +// WithConnection implements Database. +func (db *sqliteDatabase) WithConnection(ctx context.Context, exec func(Connection) error) error { + conn, err := db.Connection(ctx) + if err != nil { + return err + } + defer conn.Release() + return exec(conn) +} + // Intercept adds an interceptor function to the database. The interceptor functions // are invoked upon each query. The query will fail if the interceptor returns an error. // The interceptor can later be removed using RemoveInterceptor with the same key. @@ -1093,6 +1164,35 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e return exec(tx.conn, query, encoder, decoder) } +type sqliteConn struct { + *queryCache + db *sqliteDatabase + conn *sqlite.Conn + freeConn func() +} + +var _ Connection = &sqliteConn{} + +func (c *sqliteConn) Release() { + c.freeConn() + c.db.pool.Put(c.conn) +} + +func (c *sqliteConn) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { + if err := c.db.runInterceptors(query); err != nil { + return 0, fmt.Errorf("running query interceptors: %w", err) + } + + c.db.queryCount.Add(1) + if c.db.latency != nil { + start := time.Now() + defer func() { + c.db.latency.WithLabelValues(query).Observe(float64(time.Since(start))) + }() + } + return exec(c.conn, query, encoder, decoder) +} + func mapSqliteError(err error) error { switch sqlite.ErrCode(err) { case sqlite.SQLITE_CONSTRAINT_PRIMARYKEY, sqlite.SQLITE_CONSTRAINT_UNIQUE: diff --git a/sql/database_test.go b/sql/database_test.go index d197d5e497..899ef2b493 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -93,8 +93,6 @@ func Test_Migration_Rollback(t *testing.T) { migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) - migration2.EXPECT().Rollback().Return(nil) - dbFile := filepath.Join(t.TempDir(), "test.sql") _, err := Open("file:"+dbFile, WithDatabaseSchema(&Schema{ @@ -129,7 +127,6 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) - migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, WithLogger(logger), @@ -638,3 +635,37 @@ func TestExclusive(t *testing.T) { }) } } + +func TestConnection(t *testing.T) { + db := InMemoryTest(t) + c, err := db.Connection(context.Background()) + require.NoError(t, err) + var r int + n, err := c.Exec("select ?", func(stmt *Statement) { + stmt.BindInt64(1, 42) + }, func(stmt *Statement) bool { + r = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + require.Equal(t, 1, n) + require.Equal(t, 42, r) + c.Release() + + require.NoError(t, db.WithConnection(context.Background(), func(c Connection) error { + n, err := c.Exec("select ?", func(stmt *Statement) { + stmt.BindInt64(1, 42) + }, func(stmt *Statement) bool { + r = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + require.Equal(t, 1, n) + require.Equal(t, 42, r) + return nil + })) + + require.Error(t, db.WithConnection(context.Background(), func(c Connection) error { + return errors.New("error") + })) +} diff --git a/sql/interface.go b/sql/interface.go index c9b0ee1441..14efae19c0 100644 --- a/sql/interface.go +++ b/sql/interface.go @@ -6,13 +6,16 @@ import "go.uber.org/zap" // Executor is an interface for executing raw statement. type Executor interface { + // Exec executes a statement. Exec(string, Encoder, Decoder) (int, error) } // Migration is interface for migrations provider. type Migration interface { + // Apply applies the migration. Apply(db Executor, logger *zap.Logger) error - Rollback() error + // Name returns the name of the migration. Name() string + // Order returns the sequential number of the migration. Order() int } diff --git a/sql/migrations.go b/sql/migrations.go index 92d1a1d516..b601f30685 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -89,11 +89,6 @@ func (m *sqlMigration) Order() int { return m.order } -func (sqlMigration) Rollback() error { - // handled by the DB itself - return nil -} - func version(db Executor) (int, error) { var current int if _, err := db.Exec("PRAGMA user_version;", nil, func(stmt *Statement) bool { diff --git a/sql/mocks.go b/sql/mocks.go index 2be336b646..ae5e3413e2 100644 --- a/sql/mocks.go +++ b/sql/mocks.go @@ -216,41 +216,3 @@ func (c *MockMigrationOrderCall) DoAndReturn(f func() int) *MockMigrationOrderCa c.Call = c.Call.DoAndReturn(f) return c } - -// Rollback mocks base method. -func (m *MockMigration) Rollback() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback") - ret0, _ := ret[0].(error) - return ret0 -} - -// Rollback indicates an expected call of Rollback. -func (mr *MockMigrationMockRecorder) Rollback() *MockMigrationRollbackCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockMigration)(nil).Rollback)) - return &MockMigrationRollbackCall{Call: call} -} - -// MockMigrationRollbackCall wrap *gomock.Call -type MockMigrationRollbackCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockMigrationRollbackCall) Return(arg0 error) *MockMigrationRollbackCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockMigrationRollbackCall) Do(f func() error) *MockMigrationRollbackCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockMigrationRollbackCall) DoAndReturn(f func() error) *MockMigrationRollbackCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/sql/schema.go b/sql/schema.go index f393d7534f..144ad43eff 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "errors" "fmt" "io" "os" @@ -31,14 +30,17 @@ func LoadDBSchemaScript(db Executor) (string, error) { return "", err } fmt.Fprintf(&sb, "PRAGMA user_version = %d;\n", version) + // The following SQL query ensures that tables are listed first, + // ordered by name, and then all other objects, ordered by their table name + // and then by their own name. if _, err = db.Exec(` SELECT tbl_name, sql || ';' FROM sqlite_master WHERE sql IS NOT NULL AND tbl_name NOT LIKE 'sqlite_%' ORDER BY - CASE WHEN type = 'table' THEN 1 ELSE 2 END, -- ensures tables are first - tbl_name, -- tables are sorted by name, then all other objects - name -- (indexes, triggers, etc.) also by name + CASE WHEN type = 'table' THEN 1 ELSE 2 END, + tbl_name, + name `, nil, func(st *Statement) bool { fmt.Fprintln(&sb, st.ColumnText(1)) return true @@ -143,20 +145,13 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in db.Intercept("logQueries", logQueryInterceptor(logger)) defer db.RemoveInterceptor("logQueries") } - for i, m := range s.Migrations { + for _, m := range s.Migrations { if m.Order() <= before { continue } if err := db.WithTxImmediate(context.Background(), func(tx Transaction) error { if _, ok := s.skipMigration[m.Order()]; !ok { if err := m.Apply(tx, logger); err != nil { - for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { - if e := s.Migrations[j].Rollback(); e != nil { - err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) - break - } - } - return fmt.Errorf("apply %s: %w", m.Name(), err) } } diff --git a/sql/statesql/migrations/state_0021_migration.go b/sql/statesql/migrations/state_0021_migration.go index b88471fbeb..874e558d8a 100644 --- a/sql/statesql/migrations/state_0021_migration.go +++ b/sql/statesql/migrations/state_0021_migration.go @@ -32,10 +32,6 @@ func (*migration0021) Order() int { return 21 } -func (*migration0021) Rollback() error { - return nil -} - func (m *migration0021) Apply(db sql.Executor, logger *zap.Logger) error { if err := m.applySql(db); err != nil { return err diff --git a/sql/statesql/migrations/state_0025_migration.go b/sql/statesql/migrations/state_0025_migration.go index 71fee844fa..c3869d74c6 100644 --- a/sql/statesql/migrations/state_0025_migration.go +++ b/sql/statesql/migrations/state_0025_migration.go @@ -40,10 +40,6 @@ func (*migration0025) Order() int { return 25 } -func (*migration0025) Rollback() error { - return nil -} - func (m *migration0025) Apply(db sql.Executor, logger *zap.Logger) error { updates := map[types.NodeID][]byte{} From bb31cc63e54593ebae7d82b64148cb769d1f5c46 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 13 Nov 2024 19:20:26 +0400 Subject: [PATCH 2/4] sql: revert removing Rollback method of the Migration interface --- sql/database_test.go | 3 ++ sql/interface.go | 1 + sql/migrations.go | 5 +++ sql/mocks.go | 38 +++++++++++++++++++ sql/schema.go | 10 ++++- .../migrations/state_0021_migration.go | 4 ++ .../migrations/state_0025_migration.go | 4 ++ 7 files changed, 64 insertions(+), 1 deletion(-) diff --git a/sql/database_test.go b/sql/database_test.go index 899ef2b493..caf15fff25 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -93,6 +93,8 @@ func Test_Migration_Rollback(t *testing.T) { migration1.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(nil) migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) + migration2.EXPECT().Rollback().Return(nil) + dbFile := filepath.Join(t.TempDir(), "test.sql") _, err := Open("file:"+dbFile, WithDatabaseSchema(&Schema{ @@ -127,6 +129,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2.EXPECT().Name().Return("test").AnyTimes() migration2.EXPECT().Order().Return(2).AnyTimes() migration2.EXPECT().Apply(gomock.Any(), gomock.Any()).Return(errors.New("migration 2 failed")) + migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, WithLogger(logger), diff --git a/sql/interface.go b/sql/interface.go index 14efae19c0..6728a6b388 100644 --- a/sql/interface.go +++ b/sql/interface.go @@ -15,6 +15,7 @@ type Migration interface { // Apply applies the migration. Apply(db Executor, logger *zap.Logger) error // Name returns the name of the migration. + Rollback() error Name() string // Order returns the sequential number of the migration. Order() int diff --git a/sql/migrations.go b/sql/migrations.go index b601f30685..92d1a1d516 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -89,6 +89,11 @@ func (m *sqlMigration) Order() int { return m.order } +func (sqlMigration) Rollback() error { + // handled by the DB itself + return nil +} + func version(db Executor) (int, error) { var current int if _, err := db.Exec("PRAGMA user_version;", nil, func(stmt *Statement) bool { diff --git a/sql/mocks.go b/sql/mocks.go index ae5e3413e2..2be336b646 100644 --- a/sql/mocks.go +++ b/sql/mocks.go @@ -216,3 +216,41 @@ func (c *MockMigrationOrderCall) DoAndReturn(f func() int) *MockMigrationOrderCa c.Call = c.Call.DoAndReturn(f) return c } + +// Rollback mocks base method. +func (m *MockMigration) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockMigrationMockRecorder) Rollback() *MockMigrationRollbackCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockMigration)(nil).Rollback)) + return &MockMigrationRollbackCall{Call: call} +} + +// MockMigrationRollbackCall wrap *gomock.Call +type MockMigrationRollbackCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMigrationRollbackCall) Return(arg0 error) *MockMigrationRollbackCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMigrationRollbackCall) Do(f func() error) *MockMigrationRollbackCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMigrationRollbackCall) DoAndReturn(f func() error) *MockMigrationRollbackCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/sql/schema.go b/sql/schema.go index 144ad43eff..aa05416206 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "os" @@ -145,13 +146,20 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in db.Intercept("logQueries", logQueryInterceptor(logger)) defer db.RemoveInterceptor("logQueries") } - for _, m := range s.Migrations { + for i, m := range s.Migrations { if m.Order() <= before { continue } if err := db.WithTxImmediate(context.Background(), func(tx Transaction) error { if _, ok := s.skipMigration[m.Order()]; !ok { if err := m.Apply(tx, logger); err != nil { + for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { + if e := s.Migrations[j].Rollback(); e != nil { + err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) + break + } + } + return fmt.Errorf("apply %s: %w", m.Name(), err) } } diff --git a/sql/statesql/migrations/state_0021_migration.go b/sql/statesql/migrations/state_0021_migration.go index 874e558d8a..b88471fbeb 100644 --- a/sql/statesql/migrations/state_0021_migration.go +++ b/sql/statesql/migrations/state_0021_migration.go @@ -32,6 +32,10 @@ func (*migration0021) Order() int { return 21 } +func (*migration0021) Rollback() error { + return nil +} + func (m *migration0021) Apply(db sql.Executor, logger *zap.Logger) error { if err := m.applySql(db); err != nil { return err diff --git a/sql/statesql/migrations/state_0025_migration.go b/sql/statesql/migrations/state_0025_migration.go index c3869d74c6..71fee844fa 100644 --- a/sql/statesql/migrations/state_0025_migration.go +++ b/sql/statesql/migrations/state_0025_migration.go @@ -40,6 +40,10 @@ func (*migration0025) Order() int { return 25 } +func (*migration0025) Rollback() error { + return nil +} + func (m *migration0025) Apply(db sql.Executor, logger *zap.Logger) error { updates := map[types.NodeID][]byte{} From 3e5a4014c74296c294bc87f95854d45ce25f56c0 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 13 Nov 2024 19:33:10 +0400 Subject: [PATCH 3/4] sql: remove Database.Connection() method, keep WithConnection() --- sql/database.go | 55 +++++++++++--------------------------------- sql/database_test.go | 19 +++------------ 2 files changed, 17 insertions(+), 57 deletions(-) diff --git a/sql/database.go b/sql/database.go index 1620f633fc..c12ab1a87a 100644 --- a/sql/database.go +++ b/sql/database.go @@ -608,20 +608,15 @@ type Database interface { // and rolls it back otherwise. // If the context is canceled, the currently running SQL statement is interrupted. WithTxImmediate(ctx context.Context, exec func(Transaction) error) error - // Connection returns a connection from the database pool. + // WithConnection executes the provided function with a connection from the + // database pool. // If many queries are to be executed in a row, but there's no need for an // explicit transaction which may be long-running and thus block // WAL checkpointing, it may be preferable to use a single connection for // it to avoid database pool overhead. - // The connection needs to be always returned to the pool by calling its Release - // method. - // If the context is canceled, the currently running SQL statement is interrupted. - Connection(ctx context.Context) (Connection, error) - // WithConnection executes the provided function with a connection from the - // database pool. // The connection is released back to the pool after the function returns. // If the context is canceled, the currently running SQL statement is interrupted. - WithConnection(ctx context.Context, exec func(Connection) error) error + WithConnection(ctx context.Context, exec func(Executor) error) error // Intercept adds an interceptor function to the database. The interceptor // functions are invoked upon each query on the database, including queries // executed within transactions. @@ -642,13 +637,6 @@ type Transaction interface { Release() error } -// Connection represents a database connection. -type Connection interface { - Executor - // Release releases the connection back to the connection pool. - Release() -} - type sqliteDatabase struct { *queryCache pool *sqlitex.Pool @@ -819,28 +807,21 @@ func (db *sqliteDatabase) Close() error { return nil } -// Connection implements Database. -func (db *sqliteDatabase) Connection(ctx context.Context) (Connection, error) { +// WithConnection implements Database. +func (db *sqliteDatabase) WithConnection(ctx context.Context, exec func(Executor) error) error { if db.closed { - return nil, ErrClosed + return ErrClosed } conCtx, cancel := context.WithCancel(ctx) conn := db.getConn(conCtx) - if conn == nil { + defer func() { cancel() - return nil, ErrNoConnection - } - return &sqliteConn{queryCache: db.queryCache, db: db, conn: conn, freeConn: cancel}, nil -} - -// WithConnection implements Database. -func (db *sqliteDatabase) WithConnection(ctx context.Context, exec func(Connection) error) error { - conn, err := db.Connection(ctx) - if err != nil { - return err + db.pool.Put(conn) + }() + if conn == nil { + return ErrNoConnection } - defer conn.Release() - return exec(conn) + return exec(&sqliteConn{queryCache: db.queryCache, db: db, conn: conn}) } // Intercept adds an interceptor function to the database. The interceptor functions @@ -1166,16 +1147,8 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e type sqliteConn struct { *queryCache - db *sqliteDatabase - conn *sqlite.Conn - freeConn func() -} - -var _ Connection = &sqliteConn{} - -func (c *sqliteConn) Release() { - c.freeConn() - c.db.pool.Put(c.conn) + db *sqliteDatabase + conn *sqlite.Conn } func (c *sqliteConn) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { diff --git a/sql/database_test.go b/sql/database_test.go index caf15fff25..60b81ff9cf 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -641,22 +641,9 @@ func TestExclusive(t *testing.T) { func TestConnection(t *testing.T) { db := InMemoryTest(t) - c, err := db.Connection(context.Background()) - require.NoError(t, err) var r int - n, err := c.Exec("select ?", func(stmt *Statement) { - stmt.BindInt64(1, 42) - }, func(stmt *Statement) bool { - r = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - require.Equal(t, 1, n) - require.Equal(t, 42, r) - c.Release() - - require.NoError(t, db.WithConnection(context.Background(), func(c Connection) error { - n, err := c.Exec("select ?", func(stmt *Statement) { + require.NoError(t, db.WithConnection(context.Background(), func(ex Executor) error { + n, err := ex.Exec("select ?", func(stmt *Statement) { stmt.BindInt64(1, 42) }, func(stmt *Statement) bool { r = stmt.ColumnInt(0) @@ -668,7 +655,7 @@ func TestConnection(t *testing.T) { return nil })) - require.Error(t, db.WithConnection(context.Background(), func(c Connection) error { + require.Error(t, db.WithConnection(context.Background(), func(Executor) error { return errors.New("error") })) } From 571a53e9190c5f721896dddffe04ff755a3127a1 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 14 Nov 2024 00:21:36 +0400 Subject: [PATCH 4/4] sql: fix comment --- sql/interface.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/interface.go b/sql/interface.go index 6728a6b388..13b2a1e0ce 100644 --- a/sql/interface.go +++ b/sql/interface.go @@ -14,8 +14,9 @@ type Executor interface { type Migration interface { // Apply applies the migration. Apply(db Executor, logger *zap.Logger) error - // Name returns the name of the migration. + // Rollback rolls back the migration. Rollback() error + // Name returns the name of the migration. Name() string // Order returns the sequential number of the migration. Order() int