diff --git a/pkg/txn/client/client.go b/pkg/txn/client/client.go index 3c44856d76398..45899f4e20186 100644 --- a/pkg/txn/client/client.go +++ b/pkg/txn/client/client.go @@ -15,8 +15,10 @@ package client import ( + "bytes" "context" "encoding/hex" + "errors" "math" "runtime/debug" "sync" @@ -336,15 +338,11 @@ func (client *txnClient) doCreateTxn( cb(op) } - ts, err := client.determineTxnSnapshot(minTS) - if err != nil { - _ = op.Rollback(ctx) - return nil, err - } + ts := client.determineTxnSnapshot(minTS) if !op.opts.skipWaitPushClient { if err := op.UpdateSnapshot(ctx, ts); err != nil { _ = op.Rollback(ctx) - return nil, err + return nil, errors.Join(err, moerr.NewTxnError(ctx, "update txn snapshot")) } } @@ -356,7 +354,7 @@ func (client *txnClient) doCreateTxn( if err := op.waitActive(ctx); err != nil { _ = op.Rollback(ctx) - return nil, err + return nil, errors.Join(err, moerr.NewTxnError(ctx, "wait active")) } return op, nil } @@ -440,7 +438,7 @@ func (client *txnClient) updateLastCommitTS(event TxnEvent) { // determineTxnSnapshot assuming we determine the timestamp to be ts, the final timestamp // returned will be ts+1. This is because we need to see the submitted data for ts, and the // timestamp for all things is ts+1. -func (client *txnClient) determineTxnSnapshot(minTS timestamp.Timestamp) (timestamp.Timestamp, error) { +func (client *txnClient) determineTxnSnapshot(minTS timestamp.Timestamp) timestamp.Timestamp { start := time.Now() defer func() { v2.TxnDetermineSnapshotDurationHistogram.Observe(time.Since(start).Seconds()) @@ -457,7 +455,7 @@ func (client *txnClient) determineTxnSnapshot(minTS timestamp.Timestamp) (timest minTS = client.adjustTimestamp(minTS) } - return minTS, nil + return minTS } func (client *txnClient) adjustTimestamp(ts timestamp.Timestamp) timestamp.Timestamp { @@ -585,6 +583,8 @@ func (client *txnClient) closeTxn(event TxnEvent) { op.notifyActive() } } + } else if ok = client.removeFromWaitActiveLocked(txn.ID); ok { + client.removeFromLeakCheck(txn.ID) } else { client.logger.Warn("txn closed", zap.String("txn ID", hex.EncodeToString(txn.ID)), @@ -752,3 +752,17 @@ func (client *txnClient) handleMarkActiveTxnAborted( } } } + +func (client *txnClient) removeFromWaitActiveLocked(txnID []byte) bool { + var ok bool + values := client.mu.waitActiveTxns[:0] + for _, op := range client.mu.waitActiveTxns { + if bytes.Equal(op.reset.txnID, txnID) { + ok = true + continue + } + values = append(values, op) + } + client.mu.waitActiveTxns = values + return ok +} diff --git a/pkg/txn/client/client_test.go b/pkg/txn/client/client_test.go index fb206b94d2f3b..cc0a73b49ae58 100644 --- a/pkg/txn/client/client_test.go +++ b/pkg/txn/client/client_test.go @@ -232,8 +232,14 @@ func TestMaxActiveTxnWithWaitTimeout(t *testing.T) { defer cancel() _, err = tc.New(ctx2, newTestTimestamp(0), WithUserTxn()) require.Error(t, err) + + v := tc.(*txnClient) + v.mu.Lock() + defer v.mu.Unlock() + require.Equal(t, 0, len(v.mu.waitActiveTxns)) }, - WithMaxActiveTxn(1)) + WithMaxActiveTxn(1), + ) } func TestOpenTxnWithWaitPausedDisabled(t *testing.T) { @@ -245,3 +251,27 @@ func TestOpenTxnWithWaitPausedDisabled(t *testing.T) { require.Error(t, c.openTxn(op)) } + +func TestNewWithUpdateSnapshotTimeout(t *testing.T) { + rt := runtime.NewRuntime(metadata.ServiceType_CN, "", + logutil.GetPanicLogger(), + runtime.WithClock(clock.NewHLCClock(func() int64 { + return 1 + }, 0))) + runtime.SetupServiceBasedRuntime("", rt) + c := NewTxnClient( + "", + newTestTxnSender(), + WithEnableSacrificingFreshness(), + WithTimestampWaiter(NewTimestampWaiter(rt.Logger())), + ) + c.Resume() + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + _, err := c.New(ctx, newTestTimestamp(10000)) + assert.Error(t, err) + v := c.(*txnClient) + v.mu.Lock() + assert.Equal(t, 0, len(v.mu.waitActiveTxns)) + v.mu.Unlock() +}