diff --git a/domain/state.go b/domain/state.go index e2c9d0af2e1..bae196d3ac3 100644 --- a/domain/state.go +++ b/domain/state.go @@ -6,6 +6,7 @@ package domain import ( "context" "database/sql" + "fmt" "sync" "github.com/canonical/sqlair" @@ -99,9 +100,93 @@ func (st *StateBase) Prepare(query string, typeSamples ...any) (*sqlair.Statemen return stmt, nil } +// RunAtomic executes the closure function within the scope of a transaction. +// The closure is passed a AtomicContext that can be passed on to state +// functions, so that they can perform work within that same transaction. The +// closure will be retried according to the transaction retry semantics, if the +// transaction fails due to transient errors. The closure should only be used to +// perform state changes and must not be used to execute queries outside of the +// state scope. This includes performing goroutines or other async operations. +func (st *StateBase) RunAtomic(ctx context.Context, fn func(AtomicContext) error) error { + db, err := st.DB() + if err != nil { + return errors.Annotate(err, "getting database") + } + + return db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { + // The atomicContext is created with the transaction and passed to the + // closure function. This ensures that the transaction is always + // available to the closure. Once the transaction is complete, the + // transaction is removed from the atomicContext. This is to prevent the + // transaction from being used outside of the transaction scope. This + // will prevent any references to the sqlair.TX from being held outside + // of the transaction scope. + + // TODO (stickupkid): The atomicContext can be pooled on the StateBase + // to reduce the number of allocations. Attempting to push the tx into + // the context would prevent that as a viable option. + txCtx := &atomicContext{ + Context: ctx, + tx: tx, + } + defer txCtx.close() + + return fn(txCtx) + }) +} + +// AtomicStateBase is an interface that provides a method for executing a +// closure within the scope of a transaction. +type AtomicStateBase interface { + // RunAtomic executes the closure function within the scope of a + // transaction. The closure is passed a AtomicContext that can be passed on + // to state functions, so that they can perform work within that same + // transaction. The closure will be retried according to the transaction + // retry semantics, if the transaction fails due to transient errors. The + // closure should only be used to perform state changes and must not be used + // to execute queries outside of the state scope. This includes performing + // goroutines or other async operations. + RunAtomic(ctx context.Context, fn func(AtomicContext) error) error +} + +// Run executes the closure function using the provided AtomicContext as the +// transaction context. It is expected that the closure will perform state +// changes within the transaction scope. Any errors returned from the closure +// are coerced into a standard error to prevent sqlair errors from being +// returned to the Service layer. +func Run(ctx AtomicContext, fn func(context.Context, *sqlair.TX) error) error { + txCtx, ok := ctx.(*atomicContext) + if !ok { + // If you're seeing this error, it means that the atomicContext was not + // created by RunAtomic. This is a programming error. Did you attempt to + // wrap the context in a custom context and pass it to Run? + return fmt.Errorf("programmatic error: AtomicContext is not a *atomicContext: %T", ctx) + } + + // Ensure that we can lock the context for the duration of the run function. + // This is to prevent the transaction from being removed from the context + // or the service layer from attempting to use the transaction outside of + // the transaction scope. + txCtx.mu.Lock() + defer txCtx.mu.Unlock() + + tx := txCtx.tx + if tx == nil { + // If you're seeing this error, it means that the AtomicContext was not + // created by RunAtomic. This is a programming error. Did you capture + // the AtomicContext from a RunAtomic closure and try to use it outside + // of the closure? + return fmt.Errorf("programmatic error: AtomicContext does not have a transaction") + } + + // Execute the function with the transaction. + // Coerce the error to ensure that no sql or sqlair errors are returned + // from the function and into the Service layer. + return CoerceError(fn(ctx, tx)) +} + // txnRunner is a wrapper around a database.TxnRunner that implements the // database.TxnRunner interface. -// It is used to coerce the error returned by the database.TxnRunner into a type txnRunner struct { runner database.TxnRunner } @@ -119,3 +204,29 @@ func (r *txnRunner) Txn(ctx context.Context, fn func(context.Context, *sqlair.TX func (r *txnRunner) StdTxn(ctx context.Context, fn func(context.Context, *sql.Tx) error) error { return CoerceError(r.runner.StdTxn(ctx, fn)) } + +// AtomicContext is a typed context that provides access to the database transaction +// for the duration of a transaction. +type AtomicContext interface { + context.Context +} + +// atomicContext is the concrete implementation of the AtomicContext interface. +// The atomicContext ensures that a transaction is always available to during +// the execution of a transaction. The atomicContext stores the sqlair.TX +// directly on the struct to prevent the need to fork the context during the +// transaction. The mutex prevents data-races when the transaction is removed +// from the context. +type atomicContext struct { + context.Context + + mu sync.Mutex + tx *sqlair.TX +} + +func (c *atomicContext) close() { + c.mu.Lock() + defer c.mu.Unlock() + + c.tx = nil +} diff --git a/domain/state_test.go b/domain/state_test.go index 96bb60fb0b2..2e415eaeb19 100644 --- a/domain/state_test.go +++ b/domain/state_test.go @@ -5,11 +5,17 @@ package domain import ( "context" + "database/sql" + "fmt" + "sync/atomic" + "time" "github.com/canonical/sqlair" + jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" schematesting "github.com/juju/juju/domain/schema/testing" + "github.com/juju/juju/internal/testing" ) type stateSuite struct { @@ -96,3 +102,267 @@ func (s *stateSuite) TestStateBasePrepareKeyClash(c *gc.C) { }) c.Assert(err, gc.ErrorMatches, `cannot get result: parameter with type "domain.TestType" missing, have type with same name: "domain.TestType"`) } + +func (s *stateSuite) TestStateBaseRunAtomicTransactionExists(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the transaction is sent via the AtomicContext. + + var tx *sqlair.TX + err = base.RunAtomic(context.Background(), func(c AtomicContext) error { + tx = c.(*atomicContext).tx + return err + }) + c.Assert(err, jc.ErrorIsNil) + + c.Assert(tx, gc.NotNil) +} + +func (s *stateSuite) TestStateBaseRunAtomicPreventAtomicContextStoring(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // If the AtomicContext is stored outside of the transaction, it should + // not be possible to use it to perform state changes, as the sqlair.TX + // should be removed upon completion of the transaction. + + var txCtx AtomicContext + err = base.RunAtomic(context.Background(), func(c AtomicContext) error { + txCtx = c + return err + }) + c.Assert(err, jc.ErrorIsNil) + + c.Assert(txCtx, gc.NotNil) + + // Convert the AtomicContext to the underlying type. + c.Check(txCtx.(*atomicContext).tx, gc.IsNil) +} + +func (s *stateSuite) TestStateBaseRunAtomicContextValue(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the context is passed through to the AtomicContext. + + type contextKey string + var key contextKey = "key" + + ctx := context.WithValue(context.Background(), key, "hello") + + var dbCtx AtomicContext + err = base.RunAtomic(ctx, func(c AtomicContext) error { + dbCtx = c + return err + }) + c.Assert(err, jc.ErrorIsNil) + + c.Assert(dbCtx, gc.NotNil) + c.Check(dbCtx.Value(key), gc.Equals, "hello") +} + +func (s *stateSuite) TestStateBaseRunAtomicCancel(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Make sure that the context symantics are respected in terms of + // cancellation. + + ctx, cancel := context.WithCancel(context.Background()) + + cancel() + + err = base.RunAtomic(ctx, func(dbCtx AtomicContext) error { + c.Fatalf("should not be called") + return err + }) + c.Assert(err, jc.ErrorIs, context.Canceled) +} + +func (s *stateSuite) TestStateBaseRunAtomicWithRun(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the Run method is called. + + var called bool + err = base.RunAtomic(context.Background(), func(txCtx AtomicContext) error { + return Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + called = true + return nil + }) + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(called, jc.IsTrue) +} + +func (s *stateSuite) TestStateBaseRunAtomicWithRunMultipleTimes(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the Run method is called. + + var called int + err = base.RunAtomic(context.Background(), func(txCtx AtomicContext) error { + for i := 0; i < 10; i++ { + if err := Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + called++ + return nil + }); err != nil { + return err + } + } + return nil + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(called, gc.Equals, 10) +} + +func (s *stateSuite) TestStateBaseRunAtomicWithRunFailsConcurrently(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the run methods are correctly sequenced. Although there + // is no guarantee on the order of execution after the first run. This + // is undefined behaviour. + + var called int64 + err = base.RunAtomic(context.Background(), func(txCtx AtomicContext) error { + firstErr := make(chan error) + secondErr := make(chan error) + + start := make(chan struct{}) + go func() { + err := Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + atomic.AddInt64(&called, 1) + defer atomic.AddInt64(&called, 1) + + close(start) + + <-time.After(time.Millisecond * 100) + + return nil + }) + firstErr <- err + }() + go func() { + select { + case <-start: + case <-time.After(testing.LongWait): + secondErr <- fmt.Errorf("failed to start in time") + return + } + + err := Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + // If the first goroutine run is still running, the called + // value will be 1. If it has completed, the called value + // will be 2. This isn't exact, but it should be good enough + // to ensure that the first run has completed. + if atomic.LoadInt64(&called) != 2 { + return fmt.Errorf("called before first run completed") + } + + atomic.AddInt64(&called, 1) + + return nil + }) + secondErr <- err + }() + + select { + case err := <-firstErr: + if err != nil { + return err + } + case <-time.After(testing.LongWait): + return fmt.Errorf("failed to complete first run in time") + } + select { + case err := <-secondErr: + return err + case <-time.After(testing.LongWait): + return fmt.Errorf("failed to complete second run in time") + } + }) + c.Assert(err, jc.ErrorIsNil) + + // Ensure that this is 3. 0 implies that it was never run, 1 implies that + // the first run was never completed, 2 shows that the first run was + // completed. Lastly 3 states that everything was run. + c.Assert(called, gc.Equals, int64(3)) +} + +func (s *stateSuite) TestStateBaseRunAtomicWithRunPreparedStatements(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the Run method can use sqlair prepared statements. + + type N struct { + Name string `db:"name"` + } + + stmt, err := base.Prepare("SELECT &N.* FROM sqlite_schema WHERE name='schema'", N{}) + c.Assert(err, jc.ErrorIsNil) + + var result []N + err = base.RunAtomic(context.Background(), func(txCtx AtomicContext) error { + return Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + return tx.Query(ctx, stmt).GetAll(&result) + }) + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(result, gc.HasLen, 1) + c.Check(result[0].Name, gc.Equals, "schema") +} + +func (s *stateSuite) TestStateBaseRunAtomicWithRunDoesNotLeakError(c *gc.C) { + f := s.TxnRunnerFactory() + base := NewStateBase(f) + db, err := base.DB() + c.Assert(err, gc.IsNil) + c.Assert(db, gc.NotNil) + + // Ensure that the Run method does not leak sql.ErrNoRows. + + type N struct { + Name string `db:"name"` + } + + stmt, err := base.Prepare("SELECT &N.* FROM sqlite_schema WHERE name='something something something'", N{}) + c.Assert(err, jc.ErrorIsNil) + + var result N + err = base.RunAtomic(context.Background(), func(txCtx AtomicContext) error { + return Run(txCtx, func(ctx context.Context, tx *sqlair.TX) error { + return tx.Query(ctx, stmt).Get(&result) + }) + }) + c.Assert(err, gc.Not(jc.ErrorIs), sql.ErrNoRows) + c.Assert(err, gc.Not(jc.ErrorIs), sqlair.ErrNoRows) +}