From 10bd41ade44af7a4f575448fd989d247327160f2 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Thu, 12 Sep 2024 17:15:59 +0300 Subject: [PATCH 1/2] * Added option `ydb.WithSessionPoolSessionIdleTimeToLive` for restrict idle time of query sessions * Fixed bug with leak of query transactions --- CHANGELOG.md | 2 + internal/pool/pool.go | 32 ++-- internal/pool/pool_test.go | 14 +- internal/query/client.go | 27 ++- internal/query/client_test.go | 288 +++++++++++++++++++++++++++++- internal/query/config/config.go | 11 +- internal/query/config/options.go | 10 ++ internal/query/session/session.go | 39 ++-- internal/query/session/status.go | 3 + internal/query/transaction.go | 32 ++-- internal/table/client.go | 2 +- options.go | 10 ++ 12 files changed, 402 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df8ef0e32..fd3ddb311 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added option `ydb.WithSessionPoolSessionIdleTimeToLive` for restrict idle time of query sessions +* Fixed bug with leak of query transactions * Changed `ydb_go_sdk_ydb_driver_conn_requests` metrics splitted to `ydb_go_sdk_ydb_driver_conn_request_statuses` and `ydb_go_sdk_ydb_driver_conn_request_methods` * Fixed metadata for operation service connection * Fixed composing query traces in call `db.Query.Do[Tx]` using option `query.WithTrace` diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 17be47081..5fb9a5215 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -26,18 +26,18 @@ type ( Item } Config[PT ItemConstraint[T], T any] struct { - trace *Trace - clock clockwork.Clock - limit int - createTimeout time.Duration - createItem func(ctx context.Context) (PT, error) - closeTimeout time.Duration - closeItem func(ctx context.Context, item PT) - idleThreshold time.Duration + trace *Trace + clock clockwork.Clock + limit int + createTimeout time.Duration + createItem func(ctx context.Context) (PT, error) + closeTimeout time.Duration + closeItem func(ctx context.Context, item PT) + idleTimeToLive time.Duration } itemInfo[PT ItemConstraint[T], T any] struct { - idle *xlist.Element[PT] - touched time.Time + idle *xlist.Element[PT] + lastUsage time.Time } waitChPool[PT ItemConstraint[T], T any] interface { GetOrNew() *chan PT @@ -99,9 +99,9 @@ func WithTrace[PT ItemConstraint[T], T any](t *Trace) Option[PT, T] { } } -func WithIdleThreshold[PT ItemConstraint[T], T any](idleThreshold time.Duration) Option[PT, T] { +func WithIdleTimeToLive[PT ItemConstraint[T], T any](idleTTL time.Duration) Option[PT, T] { return func(c *Config[PT, T]) { - c.idleThreshold = idleThreshold + c.idleTimeToLive = idleTTL } } @@ -218,7 +218,7 @@ func makeAsyncCreateItemFunc[PT ItemConstraint[T], T any]( //nolint:funlen if newItem != nil { p.mu.WithLock(func() { p.index[newItem] = itemInfo[PT, T]{ - touched: p.config.clock.Now(), + lastUsage: p.config.clock.Now(), } }) } @@ -461,7 +461,7 @@ func (p *Pool[PT, T]) peekFirstIdle() (item PT, touched time.Time) { panic(fmt.Sprintf("inconsistent index: (%v, %+v, %+v)", has, el, info.idle)) } - return item, info.touched + return item, info.lastUsage } // removes first session from idle and resets the keepAliveCount @@ -547,7 +547,7 @@ func (p *Pool[PT, T]) pushIdle(item PT, now time.Time) { } p.changeState(func() Stats { - info.touched = now + info.lastUsage = now info.idle = p.idle.PushBack(item) p.index[item] = info @@ -595,7 +595,7 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) { / return info }) - if p.config.idleThreshold > 0 && p.config.clock.Since(info.touched) > p.config.idleThreshold { + if p.config.idleTimeToLive > 0 && p.config.clock.Since(info.lastUsage) > p.config.idleTimeToLive { p.closeItem(ctx, item) p.mu.WithLock(func() { p.changeState(func() Stats { diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index d22c48a66..2c6c58639 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -394,7 +394,7 @@ func TestPool(t *testing.T) { // replace default async closer for sync testing WithSyncCloseItem[*testItem, testItem](), WithClock[*testItem, testItem](fakeClock), - WithIdleThreshold[*testItem, testItem](idleThreshold), + WithIdleTimeToLive[*testItem, testItem](idleThreshold), WithTrace[*testItem, testItem](defaultTrace), ) @@ -402,14 +402,14 @@ func TestPool(t *testing.T) { s2 := mustGetItem(t, p) // Put both items at the absolutely same time. - // That is, both items must be updated their touched timestamp. + // That is, both items must be updated their lastUsage timestamp. mustPutItem(t, p, s1) mustPutItem(t, p, s2) require.Len(t, p.index, 2) require.Equal(t, 2, p.idle.Len()) - // Move clock to longer than idleThreshold + // Move clock to longer than idleTimeToLive fakeClock.Advance(idleThreshold + time.Nanosecond) // on get item from idle list the pool must check the item idle timestamp @@ -423,15 +423,15 @@ func TestPool(t *testing.T) { t.Fatal("unexpected number of closed items") } - // Move time to idleThreshold / 2 - this emulate a "spent" some time working within item. + // Move time to idleTimeToLive / 2 - this emulate a "spent" some time working within item. fakeClock.Advance(idleThreshold / 2) // Now put that item back - // pool must update a touched timestamp of item + // pool must update a lastUsage timestamp of item mustPutItem(t, p, s3) - // Move time to idleThreshold / 2 - // Total time since last updating touched timestampe is more than idleThreshold + // Move time to idleTimeToLive / 2 + // Total time since last updating lastUsage timestampe is more than idleTimeToLive fakeClock.Advance(idleThreshold/2 + time.Nanosecond) require.Len(t, p.index, 1) diff --git a/internal/query/client.go b/internal/query/client.go index 5f9f04104..3091de0d4 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -205,9 +205,7 @@ func do( err := op(ctx, s) if err != nil { - if xerrors.IsOperationError(err) { - s.SetStatus(session.StatusClosed) - } + s.SetStatus(session.StatusError) return xerrors.WithStackTrace(err) } @@ -263,27 +261,27 @@ func doTx( txSettings tx.Settings, opts ...retry.Option, ) (finalErr error) { - err := do(ctx, pool, func(ctx context.Context, s *Session) (err error) { + err := do(ctx, pool, func(ctx context.Context, s *Session) (opErr error) { tx, err := s.Begin(ctx, txSettings) if err != nil { return xerrors.WithStackTrace(err) } - err = op(ctx, tx) - if err != nil { - errRollback := tx.Rollback(ctx) - if errRollback != nil { - return xerrors.WithStackTrace(xerrors.Join(err, errRollback)) + + defer func() { + _ = tx.Rollback(ctx) + + if opErr != nil { + s.SetStatus(session.StatusError) } + }() + err = op(ctx, tx) + if err != nil { return xerrors.WithStackTrace(err) } + err = tx.CommitTx(ctx) if err != nil { - errRollback := tx.Rollback(ctx) - if errRollback != nil { - return xerrors.WithStackTrace(xerrors.Join(err, errRollback)) - } - return xerrors.WithStackTrace(err) } @@ -530,6 +528,7 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * pool.WithTrace[*Session, Session](poolTrace(cfg.Trace())), pool.WithCreateItemTimeout[*Session, Session](cfg.SessionCreateTimeout()), pool.WithCloseItemTimeout[*Session, Session](cfg.SessionDeleteTimeout()), + pool.WithIdleTimeToLive[*Session, Session](cfg.SessionIdleTimeToLive()), pool.WithCreateItemFunc(func(ctx context.Context) (_ *Session, err error) { var ( createCtx context.Context diff --git a/internal/query/client_test.go b/internal/query/client_test.go index eb6d494e1..1ba9fb625 100644 --- a/internal/query/client_test.go +++ b/internal/query/client_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -13,6 +14,7 @@ import ( "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats" "go.uber.org/mock/gomock" + "google.golang.org/grpc" grpcCodes "google.golang.org/grpc/codes" grpcStatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" @@ -279,6 +281,284 @@ func TestClient(t *testing.T) { require.NoError(t, err) require.Equal(t, 10, counter) }) + t.Run("TxLeak", func(t *testing.T) { + t.Run("OnExec", func(t *testing.T) { + t.Run("WithExplicitCommit", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + txInFlight := 0 + ctrl := gomock.NewController(t) + err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + if rand.Int31n(100) < 50 { + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + txInFlight++ + + stream.EXPECT().Recv().Return(nil, io.EOF) + + client.EXPECT().CommitTransaction(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.CommitTransactionRequest, option ...grpc.CallOption) ( + *Ydb_Query.CommitTransactionResponse, error, + ) { + txInFlight-- + + return &Ydb_Query.CommitTransactionResponse{ + Status: Ydb.StatusIds_SUCCESS, + }, nil + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", + }, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return stream, nil + }) + + return newTestSessionWithClient("123", client), nil + }), func(ctx context.Context, tx query.TxActor) error { + return tx.Exec(ctx, "") + }, tx.NewSettings(tx.WithSerializableReadWrite())) + require.NoError(t, err) + require.Zero(t, txInFlight) + }) + }) + t.Run("WithLazyCommit", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + ctrl := gomock.NewController(t) + txInFlight := 0 + err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + require.True(t, request.GetTxControl().GetCommitTx()) + + if rand.Int31n(100) < 50 { + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + if rand.Int31n(100) < 50 { + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + txInFlight++ + + stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + txInFlight-- + + return nil, io.EOF + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", + }, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return stream, nil + }) + + return newTestSessionWithClient("123", client), nil + }), func(ctx context.Context, tx query.TxActor) error { + return tx.Exec(ctx, "", options.WithCommit()) + }, tx.NewSettings(tx.WithSerializableReadWrite())) + require.NoError(t, err) + require.Zero(t, txInFlight) + }) + }) + }) + t.Run("OnSecondExec", func(t *testing.T) { + t.Run("WithExplicitCommit", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + ctrl := gomock.NewController(t) + txInFlight := 0 + err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + if rand.Int31n(100) < 50 { + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + firstStream := NewMockQueryService_ExecuteQueryClient(ctrl) + firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + txInFlight++ + + firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + if rand.Int31n(100) < 50 { + client.EXPECT().RollbackTransaction(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, + request *Ydb_Query.RollbackTransactionRequest, + option ...grpc.CallOption, + ) (*Ydb_Query.RollbackTransactionResponse, error) { + txInFlight-- + + return &Ydb_Query.RollbackTransactionResponse{}, nil + }) + + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + secondStream := NewMockQueryService_ExecuteQueryClient(ctrl) + secondStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + secondStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + client.EXPECT().CommitTransaction(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.CommitTransactionRequest, option ...grpc.CallOption) ( + *Ydb_Query.CommitTransactionResponse, error, + ) { + txInFlight-- + + return &Ydb_Query.CommitTransactionResponse{ + Status: Ydb.StatusIds_SUCCESS, + }, nil + }) + + return nil, io.EOF + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{}, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return secondStream, nil + }) + + return nil, io.EOF + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", + }, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return firstStream, nil + }) + + return newTestSessionWithClient("123", client), nil + }), func(ctx context.Context, tx query.TxActor) error { + if err := tx.Exec(ctx, ""); err != nil { + return err + } + + return tx.Exec(ctx, "") + }, tx.NewSettings(tx.WithSerializableReadWrite())) + require.NoError(t, err) + }) + }) + t.Run("WithLazyCommit", func(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + ctrl := gomock.NewController(t) + txInFlight := 0 + err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + client := NewMockQueryServiceClient(ctrl) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + if rand.Int31n(100) < 50 { + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + firstStream := NewMockQueryService_ExecuteQueryClient(ctrl) + firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + txInFlight++ + + firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( + Ydb_Query_V1.QueryService_ExecuteQueryClient, error, + ) { + require.True(t, request.GetTxControl().GetCommitTx()) + + if rand.Int31n(100) < 50 { + client.EXPECT().RollbackTransaction(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, + request *Ydb_Query.RollbackTransactionRequest, + option ...grpc.CallOption, + ) (*Ydb_Query.RollbackTransactionResponse, error) { + txInFlight-- + + return &Ydb_Query.RollbackTransactionResponse{}, nil + }) + + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) + } + + secondStream := NewMockQueryService_ExecuteQueryClient(ctrl) + secondStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + secondStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + return nil, io.EOF + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{}, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return secondStream, nil + }) + + return nil, io.EOF + }) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", + }, + ExecStats: &Ydb_TableStats.QueryStats{}, + }, nil + }) + + return firstStream, nil + }) + + return newTestSessionWithClient("123", client), nil + }), func(ctx context.Context, tx query.TxActor) error { + if err := tx.Exec(ctx, ""); err != nil { + return err + } + + return tx.Exec(ctx, "", options.WithCommit()) + }, tx.NewSettings(tx.WithSerializableReadWrite())) + require.NoError(t, err) + }) + }) + }) + }) }) t.Run("Exec", func(t *testing.T) { t.Run("HappyWay", func(t *testing.T) { @@ -1151,12 +1431,7 @@ type sessionControllerMock struct { } func (s *sessionControllerMock) IsAlive() bool { - switch s.status { - case session.StatusClosed, session.StatusClosing: - return false - default: - return true - } + return session.IsAlive(s.status) } func (s *sessionControllerMock) Close(ctx context.Context) error { @@ -1201,6 +1476,7 @@ func testPool( return pool.New[*Session, Session](ctx, pool.WithLimit[*Session, Session](1), pool.WithCreateItemFunc(createSession), + pool.WithSyncCloseItem[*Session, Session](), ) } diff --git a/internal/query/config/config.go b/internal/query/config/config.go index 7adb08242..befa08447 100644 --- a/internal/query/config/config.go +++ b/internal/query/config/config.go @@ -19,8 +19,9 @@ type Config struct { poolLimit int - sessionCreateTimeout time.Duration - sessionDeleteTimeout time.Duration + sessionCreateTimeout time.Duration + sessionDeleteTimeout time.Duration + sessionIddleTimeToLive time.Duration trace *trace.Query } @@ -68,3 +69,9 @@ func (c *Config) SessionCreateTimeout() time.Duration { func (c *Config) SessionDeleteTimeout() time.Duration { return c.sessionDeleteTimeout } + +// SessionIdleTimeToLive limits maximum time to live of idle session +// If idleTimeToLive is less than or equal to zero then sessions will not be closed by idle +func (c *Config) SessionIdleTimeToLive() time.Duration { + return c.sessionIddleTimeToLive +} diff --git a/internal/query/config/options.go b/internal/query/config/options.go index 2b30c5be2..01e040da1 100644 --- a/internal/query/config/options.go +++ b/internal/query/config/options.go @@ -55,3 +55,13 @@ func WithSessionDeleteTimeout(deleteTimeout time.Duration) Option { } } } + +// WithSessionIdleTimeToLive limits maximum time to live of idle session +// If idleTimeToLive is less than or equal to zero then sessions will not be closed by idle +func WithSessionIdleTimeToLive(idleTimeToLive time.Duration) Option { + return func(c *Config) { + if idleTimeToLive > 0 { + c.sessionIddleTimeToLive = idleTimeToLive + } + } +} diff --git a/internal/query/session/session.go b/internal/query/session/session.go index 142728953..54612b5d7 100644 --- a/internal/query/session/session.go +++ b/internal/query/session/session.go @@ -2,7 +2,7 @@ package session import ( "context" - "sync/atomic" + "fmt" "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" @@ -50,11 +50,26 @@ func (c *core) NodeID() uint32 { } func (c *core) statusCode() Status { - return Status(atomic.LoadUint32((*uint32)(&c.status))) + return c.status } func (c *core) SetStatus(status Status) { - atomic.StoreUint32((*uint32)(&c.status), uint32(status)) + switch c.status { + case statusUnknown: + c.status = status + case StatusIdle: + c.status = status + case StatusInUse: + c.status = status + case StatusClosing: + c.status = status + case StatusClosed: + c.status = status + case StatusError: + c.status = status + default: + panic(fmt.Sprintf("Unknown%d", c.status)) + } } func (c *core) Status() string { @@ -81,7 +96,16 @@ func WithTrace(t *trace.Query) Option { } } -func Open( //nolint:funlen +func IsAlive(status Status) bool { + switch status { + case StatusClosed, StatusClosing, StatusError: + return false + default: + return true + } +} + +func Open( ctx context.Context, client Ydb_Query_V1.QueryServiceClient, opts ...Option, ) (_ *core, finalErr error) { core := &core{ @@ -90,12 +114,7 @@ func Open( //nolint:funlen status: statusUnknown, checks: []func(s *core) bool{ func(s *core) bool { - switch s.statusCode() { - case StatusClosed, StatusClosing: - return false - default: - return true - } + return IsAlive(s.status) }, }, } diff --git a/internal/query/session/status.go b/internal/query/session/status.go index e8123dee0..5f7985626 100644 --- a/internal/query/session/status.go +++ b/internal/query/session/status.go @@ -12,6 +12,7 @@ const ( StatusInUse StatusClosing StatusClosed + StatusError ) func (s Status) String() string { @@ -26,6 +27,8 @@ func (s Status) String() string { return "Closing" case StatusClosed: return "Closed" + case StatusError: + return "Error" default: return fmt.Sprintf("Unknown%d", s) } diff --git a/internal/query/transaction.go b/internal/query/transaction.go index 50e10a403..4b43a5288 100644 --- a/internal/query/transaction.go +++ b/internal/query/transaction.go @@ -113,7 +113,7 @@ func (tx *Transaction) QueryResultSet( } if settings.TxControl().Commit { - if txID != nil { + if txID != nil && tx.Identifier != nil { return nil, xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) } tx.completed = true @@ -193,7 +193,7 @@ func (tx *Transaction) txControl() *queryTx.Control { func (tx *Transaction) ID() string { if tx.Identifier == nil { - return "LAZY_TX" + return LazyTxID } return tx.Identifier.ID() @@ -236,7 +236,7 @@ func (tx *Transaction) Exec(ctx context.Context, q string, opts ...options.Execu } if settings.TxControl().Commit { - if txID != nil { + if txID != nil && tx.Identifier != nil { return xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) } tx.completed = true @@ -322,7 +322,7 @@ func (tx *Transaction) Query(ctx context.Context, q string, opts ...options.Exec } if settings.TxControl().Commit { - if txID != nil { + if txID != nil && tx.Identifier != nil { return nil, xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) } tx.completed = true @@ -350,17 +350,21 @@ func commitTx(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi return nil } -func (tx *Transaction) CommitTx(ctx context.Context) (err error) { - defer func() { - tx.notifyOnCompleted(err) - tx.completed = true - }() - +func (tx *Transaction) CommitTx(ctx context.Context) (finalErr error) { if tx.Identifier == nil { return nil } - err = commitTx(ctx, tx.s.client, tx.s.ID(), tx.ID()) + if tx.completed { + return nil + } + + defer func() { + tx.notifyOnCompleted(finalErr) + tx.completed = true + }() + + err := commitTx(ctx, tx.s.client, tx.s.ID(), tx.ID()) if err != nil { if xerrors.IsOperationError(err, Ydb.StatusIds_BAD_SESSION) { tx.s.SetStatus(session.StatusClosed) @@ -384,11 +388,15 @@ func rollback(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi return nil } -func (tx *Transaction) Rollback(ctx context.Context) error { +func (tx *Transaction) Rollback(ctx context.Context) (finalErr error) { if tx.Identifier == nil { return nil } + if tx.completed { + return nil + } + tx.completed = true tx.notifyOnCompleted(ErrTransactionRollingBack) diff --git a/internal/table/client.go b/internal/table/client.go index 5ce8e08dc..d65958ca0 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -33,7 +33,7 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, config *config.Config }, pool: pool.New[*session, session](ctx, pool.WithLimit[*session, session](config.SizeLimit()), - pool.WithIdleThreshold[*session, session](config.IdleThreshold()), + pool.WithIdleTimeToLive[*session, session](config.IdleThreshold()), pool.WithCreateItemTimeout[*session, session](config.CreateSessionTimeout()), pool.WithCloseItemTimeout[*session, session](config.DeleteTimeout()), pool.WithClock[*session, session](config.Clock()), diff --git a/options.go b/options.go index 304d1439a..668d607a1 100644 --- a/options.go +++ b/options.go @@ -501,6 +501,16 @@ func WithSessionPoolIdleThreshold(idleThreshold time.Duration) Option { } } +// WithSessionPoolSessionIdleTimeToLive limits maximum time to live of idle session +// If idleTimeToLive is less than or equal to zero then sessions will not be closed by idle +func WithSessionPoolSessionIdleTimeToLive(idleThreshold time.Duration) Option { + return func(ctx context.Context, c *Driver) error { + c.queryOptions = append(c.queryOptions, queryConfig.WithSessionIdleTimeToLive(idleThreshold)) + + return nil + } +} + // WithSessionPoolCreateSessionTimeout set timeout for new session creation process in table.Client func WithSessionPoolCreateSessionTimeout(createSessionTimeout time.Duration) Option { return func(ctx context.Context, c *Driver) error { From d26f929bd5e08276eeb05f425107c56443008be8 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Fri, 13 Sep 2024 22:33:31 +0300 Subject: [PATCH 2/2] refactored lazy tx --- internal/query/client.go | 10 +- internal/query/client_test.go | 157 ++++++++++---------- internal/query/errors.go | 19 ++- internal/query/execute_query.go | 15 +- internal/query/execute_query_test.go | 31 +++- internal/query/result.go | 21 ++- internal/query/result_go1.23_test.go | 2 +- internal/query/result_test.go | 34 ++--- internal/query/session.go | 8 +- internal/query/session/session.go | 36 ++--- internal/query/transaction.go | 127 +++++----------- internal/query/transaction_fixtures_test.go | 8 +- internal/query/transaction_test.go | 12 +- internal/tx/id.go | 28 +++- tests/integration/query_execute_test.go | 122 +++++++++++++++ tests/integration/query_tx_execute_test.go | 23 +-- 16 files changed, 383 insertions(+), 270 deletions(-) diff --git a/internal/query/client.go b/internal/query/client.go index 3091de0d4..43e4430b7 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -336,12 +336,12 @@ func (c *Client) QueryRow(ctx context.Context, q string, opts ...options.Execute func clientExec(ctx context.Context, pool sessionPool, q string, opts ...options.Execute) (finalErr error) { settings := options.ExecuteSettings(opts...) err := do(ctx, pool, func(ctx context.Context, s *Session) (err error) { - _, r, err := execute(ctx, s.ID(), s.client, q, settings, withTrace(s.trace)) + streamResult, err := execute(ctx, s.ID(), s.client, q, settings, withTrace(s.trace)) if err != nil { return xerrors.WithStackTrace(err) } - err = readAll(ctx, r) + err = readAll(ctx, streamResult) if err != nil { return xerrors.WithStackTrace(err) } @@ -380,7 +380,7 @@ func clientQuery(ctx context.Context, pool sessionPool, q string, opts ...option ) { settings := options.ExecuteSettings(opts...) err = do(ctx, pool, func(ctx context.Context, s *Session) (err error) { - _, streamResult, err := execute(ctx, s.ID(), s.client, q, + streamResult, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace), ) if err != nil { @@ -432,12 +432,12 @@ func clientQueryResultSet( ctx context.Context, pool sessionPool, q string, settings executeSettings, resultOpts ...resultOption, ) (rs result.ClosableResultSet, finalErr error) { err := do(ctx, pool, func(ctx context.Context, s *Session) error { - _, r, err := execute(ctx, s.ID(), s.client, q, settings, resultOpts...) + streamResult, err := execute(ctx, s.ID(), s.client, q, settings, resultOpts...) if err != nil { return xerrors.WithStackTrace(err) } - rs, err = readMaterializedResultSet(ctx, r) + rs, err = readMaterializedResultSet(ctx, streamResult) if err != nil { return xerrors.WithStackTrace(err) } diff --git a/internal/query/client_test.go b/internal/query/client_test.go index 1ba9fb625..e0ac2a8c4 100644 --- a/internal/query/client_test.go +++ b/internal/query/client_test.go @@ -175,76 +175,80 @@ func TestClient(t *testing.T) { t.Run("DoTx", func(t *testing.T) { t.Run("HappyWay", func(t *testing.T) { ctrl := gomock.NewController(t) - client := NewMockQueryServiceClient(ctrl) - stream := NewMockQueryService_ExecuteQueryClient(ctrl) - stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ - Status: Ydb.StatusIds_SUCCESS, - TxMeta: &Ydb_Query.TransactionMeta{ - Id: "456", - }, - ResultSetIndex: 0, - ResultSet: &Ydb.ResultSet{ - Columns: []*Ydb.Column{ - { - Name: "a", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UINT64, - }, - }, - }, - { - Name: "b", - Type: &Ydb.Type{ - Type: &Ydb.Type_TypeId{ - TypeId: Ydb.Type_UTF8, - }, - }, + err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + client := NewMockQueryServiceClient(ctrl) + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { + client.EXPECT().CommitTransaction(gomock.Any(), gomock.Any()).Return(&Ydb_Query.CommitTransactionResponse{ + Status: Ydb.StatusIds_SUCCESS, + }, nil) + + return &Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", }, - }, - Rows: []*Ydb.Value{ - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 1, - }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "1", + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 2, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "2", + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, }, - }}, - }, - { - Items: []*Ydb.Value{{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: 3, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, }, - }, { - Value: &Ydb.Value_TextValue{ - TextValue: "3", + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, }, - }}, + }, }, - }, - }, - }, nil) - stream.EXPECT().Recv().Return(nil, io.EOF) - client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) - client.EXPECT().CommitTransaction(gomock.Any(), gomock.Any()).Return(&Ydb_Query.CommitTransactionResponse{ - Status: Ydb.StatusIds_SUCCESS, - }, nil) - err := doTx(ctx, testPool(ctx, func(ctx context.Context) (*Session, error) { + }, nil + }) + stream.EXPECT().Recv().Return(nil, io.EOF) + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) + return newTestSessionWithClient("123", client), nil }), func(ctx context.Context, tx query.TxActor) error { defer func() { @@ -283,7 +287,7 @@ func TestClient(t *testing.T) { }) t.Run("TxLeak", func(t *testing.T) { t.Run("OnExec", func(t *testing.T) { - t.Run("WithExplicitCommit", func(t *testing.T) { + t.Run("WithoutCommit", func(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { txInFlight := 0 ctrl := gomock.NewController(t) @@ -297,12 +301,11 @@ func TestClient(t *testing.T) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) } + txInFlight++ + stream := NewMockQueryService_ExecuteQueryClient(ctrl) stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { - txInFlight++ - stream.EXPECT().Recv().Return(nil, io.EOF) - client.EXPECT().CommitTransaction(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, request *Ydb_Query.CommitTransactionRequest, option ...grpc.CallOption) ( *Ydb_Query.CommitTransactionResponse, error, @@ -334,7 +337,7 @@ func TestClient(t *testing.T) { require.Zero(t, txInFlight) }) }) - t.Run("WithLazyCommit", func(t *testing.T) { + t.Run("WithCommit", func(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { ctrl := gomock.NewController(t) txInFlight := 0 @@ -350,14 +353,16 @@ func TestClient(t *testing.T) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) } + txInFlight++ + stream := NewMockQueryService_ExecuteQueryClient(ctrl) stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { if rand.Int31n(100) < 50 { + txInFlight-- + return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) } - txInFlight++ - stream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { txInFlight-- @@ -386,7 +391,7 @@ func TestClient(t *testing.T) { }) }) t.Run("OnSecondExec", func(t *testing.T) { - t.Run("WithExplicitCommit", func(t *testing.T) { + t.Run("WithoutCommit", func(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { ctrl := gomock.NewController(t) txInFlight := 0 @@ -400,10 +405,10 @@ func TestClient(t *testing.T) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) } + txInFlight++ + firstStream := NewMockQueryService_ExecuteQueryClient(ctrl) firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { - txInFlight++ - firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( @@ -476,7 +481,7 @@ func TestClient(t *testing.T) { require.NoError(t, err) }) }) - t.Run("WithLazyCommit", func(t *testing.T) { + t.Run("WithCommit", func(t *testing.T) { xtest.TestManyTimes(t, func(t testing.TB) { ctrl := gomock.NewController(t) txInFlight := 0 @@ -490,10 +495,10 @@ func TestClient(t *testing.T) { return nil, xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION)) } + txInFlight++ + firstStream := NewMockQueryService_ExecuteQueryClient(ctrl) firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { - txInFlight++ - firstStream.EXPECT().Recv().DoAndReturn(func() (*Ydb_Query.ExecuteQueryResponsePart, error) { client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, request *Ydb_Query.ExecuteQueryRequest, option ...grpc.CallOption) ( diff --git a/internal/query/errors.go b/internal/query/errors.go index 88c5f01e6..2f78e9b4a 100644 --- a/internal/query/errors.go +++ b/internal/query/errors.go @@ -7,14 +7,13 @@ import ( ) var ( - ErrTransactionRollingBack = xerrors.Wrap(errors.New("ydb: the transaction is rolling back")) - errWrongNextResultSetIndex = errors.New("wrong result set index") - errWrongResultSetIndex = errors.New("critical violation of the logic - wrong result set index") - errMoreThanOneRow = errors.New("unexpected more than one row in result set") - errMoreThanOneResultSet = errors.New("unexpected more than one result set") - errNoResultSets = errors.New("no result sets") - errUnexpectedTxIDOnCommitFlag = errors.New("unexpected transaction ID on commit flag") - errExpectedTxID = errors.New("expected transaction ID but nil") - ErrOptionNotForTxExecute = errors.New("option is not for execute on transaction") - errExecuteOnCompletedTx = errors.New("execute on completed transaction") + ErrTransactionRollingBack = xerrors.Wrap(errors.New("ydb: the transaction is rolling back")) + errWrongNextResultSetIndex = errors.New("wrong result set index") + errWrongResultSetIndex = errors.New("critical violation of the logic - wrong result set index") + errMoreThanOneRow = errors.New("unexpected more than one row in result set") + errMoreThanOneResultSet = errors.New("unexpected more than one result set") + errNoResultSets = errors.New("no result sets") + errNilOption = errors.New("nil option") + ErrOptionNotForTxExecute = errors.New("option is not for execute on transaction") + errExecuteOnCompletedTx = errors.New("execute on completed transaction") ) diff --git a/internal/query/execute_query.go b/internal/query/execute_query.go index 7b389f71e..9d535c4dc 100644 --- a/internal/query/execute_query.go +++ b/internal/query/execute_query.go @@ -14,7 +14,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/query" @@ -100,7 +99,7 @@ func execute( ctx context.Context, sessionID string, c Ydb_Query_V1.QueryServiceClient, q string, settings executeSettings, opts ...resultOption, ) ( - _ tx.Identifier, _ *streamResult, finalErr error, + _ *streamResult, finalErr error, ) { a := allocator.New() defer a.Free() @@ -111,19 +110,15 @@ func execute( stream, err := c.ExecuteQuery(executeCtx, request, callOptions...) if err != nil { - return nil, nil, xerrors.WithStackTrace(err) + return nil, xerrors.WithStackTrace(err) } - r, txID, err := newResult(ctx, stream, append(opts, withStatsCallback(settings.StatsCallback()))...) + r, err := newResult(ctx, stream, append(opts, withStatsCallback(settings.StatsCallback()))...) if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - - if txID == "" { - return nil, r, nil + return nil, xerrors.WithStackTrace(err) } - return tx.ID(txID), r, nil + return r, nil } func readAll(ctx context.Context, r *streamResult) error { diff --git a/internal/query/execute_query_test.go b/internal/query/execute_query_test.go index 07ea99b23..5698ded92 100644 --- a/internal/query/execute_query_test.go +++ b/internal/query/execute_query_test.go @@ -356,10 +356,15 @@ func TestExecute(t *testing.T) { stream.EXPECT().Recv().Return(nil, io.EOF) client := NewMockQueryServiceClient(ctrl) client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) - tx, r, err := execute(ctx, "123", client, "", options.ExecuteSettings()) + var txID string + r, err := execute(ctx, "123", client, "", options.ExecuteSettings(), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + txID = txMeta.GetId() + }), + ) require.NoError(t, err) defer r.Close(ctx) - require.EqualValues(t, "456", tx.ID()) + require.EqualValues(t, "456", txID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") @@ -466,7 +471,7 @@ func TestExecute(t *testing.T) { client := NewMockQueryServiceClient(ctrl) client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(nil, grpcStatus.Error(grpcCodes.Unavailable, "")) t.Log("execute") - _, _, err := execute(ctx, "123", client, "", options.ExecuteSettings()) + _, err := execute(ctx, "123", client, "", options.ExecuteSettings()) require.Error(t, err) require.True(t, xerrors.IsTransportError(err, grpcCodes.Unavailable)) }) @@ -570,10 +575,15 @@ func TestExecute(t *testing.T) { client := NewMockQueryServiceClient(ctrl) client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - tx, r, err := execute(ctx, "123", client, "", options.ExecuteSettings()) + var txID string + r, err := execute(ctx, "123", client, "", options.ExecuteSettings(), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + txID = txMeta.GetId() + }), + ) require.NoError(t, err) defer r.Close(ctx) - require.EqualValues(t, "456", tx.ID()) + require.EqualValues(t, "456", txID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") @@ -630,7 +640,7 @@ func TestExecute(t *testing.T) { client := NewMockQueryServiceClient(ctrl) client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - _, _, err := execute(ctx, "123", client, "", options.ExecuteSettings()) + _, err := execute(ctx, "123", client, "", options.ExecuteSettings()) require.Error(t, err) require.True(t, xerrors.IsOperationError(err, Ydb.StatusIds_UNAVAILABLE)) }) @@ -706,10 +716,15 @@ func TestExecute(t *testing.T) { client := NewMockQueryServiceClient(ctrl) client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - tx, r, err := execute(ctx, "123", client, "", options.ExecuteSettings()) + var txID string + r, err := execute(ctx, "123", client, "", options.ExecuteSettings(), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + txID = txMeta.GetId() + }), + ) require.NoError(t, err) defer r.Close(ctx) - require.EqualValues(t, "456", tx.ID()) + require.EqualValues(t, "456", txID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") diff --git a/internal/query/result.go b/internal/query/result.go index 079d997f6..3f29e1f7b 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -38,6 +38,7 @@ type ( trace *trace.Query statsCallback func(queryStats stats.QueryStats) onNextPartErr []func(err error) + onTxMeta []func(txMeta *Ydb_Query.TransactionMeta) } resultOption func(s *streamResult) ) @@ -101,11 +102,17 @@ func onNextPartErr(callback func(err error)) resultOption { } } +func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption { + return func(s *streamResult) { + s.onTxMeta = append(s.onTxMeta, callback) + } +} + func newResult( ctx context.Context, stream Ydb_Query_V1.QueryService_ExecuteQueryClient, opts ...resultOption, -) (_ *streamResult, txID string, finalErr error) { +) (_ *streamResult, finalErr error) { r := streamResult{ stream: stream, closed: make(chan struct{}), @@ -133,11 +140,11 @@ func newResult( select { case <-ctx.Done(): - return nil, txID, xerrors.WithStackTrace(ctx.Err()) + return nil, xerrors.WithStackTrace(ctx.Err()) default: part, err := r.nextPart(ctx) if err != nil { - return nil, txID, xerrors.WithStackTrace(err) + return nil, xerrors.WithStackTrace(err) } r.lastPart = part @@ -146,7 +153,7 @@ func newResult( r.statsCallback(stats.FromQueryStats(part.GetExecStats())) } - return &r, part.GetTxMeta().GetId(), nil + return &r, nil } } @@ -177,6 +184,12 @@ func (r *streamResult) nextPart(ctx context.Context) ( return nil, xerrors.WithStackTrace(err) } + if txMeta := part.GetTxMeta(); txMeta != nil { + for _, f := range r.onTxMeta { + f(txMeta) + } + } + return part, nil } } diff --git a/internal/query/result_go1.23_test.go b/internal/query/result_go1.23_test.go index 945c7f349..a45703088 100644 --- a/internal/query/result_go1.23_test.go +++ b/internal/query/result_go1.23_test.go @@ -344,7 +344,7 @@ func TestResultRangeResultSets(t *testing.T) { }, }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) defer r.Close(ctx) rsCount := 0 diff --git a/internal/query/result_test.go b/internal/query/result_test.go index 4be2ac98e..faa18d2d5 100644 --- a/internal/query/result_test.go +++ b/internal/query/result_test.go @@ -351,7 +351,7 @@ func TestResultNextResultSet(t *testing.T) { }, }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) defer r.Close(ctx) { @@ -518,7 +518,7 @@ func TestResultNextResultSet(t *testing.T) { }, }, }, nil) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) defer r.Close(ctx) { @@ -833,7 +833,7 @@ func TestResultNextResultSet(t *testing.T) { }, }, }, nil) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) defer r.Close(ctx) { @@ -937,7 +937,7 @@ func TestExactlyOneRowFromResult(t *testing.T) { }, }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) row, err := exactlyOneRowFromResult(ctx, r) @@ -1005,7 +1005,7 @@ func TestExactlyOneRowFromResult(t *testing.T) { }, }, }, nil) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) row, err := exactlyOneRowFromResult(ctx, r) @@ -1057,7 +1057,7 @@ func TestExactlyOneRowFromResult(t *testing.T) { }, nil) testErr := errors.New("test-err") stream.EXPECT().Recv().Return(nil, testErr) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) row, err := exactlyOneRowFromResult(ctx, r) @@ -1147,7 +1147,7 @@ func TestExactlyOneRowFromResult(t *testing.T) { }, }, }, nil) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) row, err := exactlyOneResultSetFromResult(ctx, r) @@ -1199,7 +1199,7 @@ func TestExactlyOneRowFromResult(t *testing.T) { }, nil) testErr := errors.New("test-err") stream.EXPECT().Recv().Return(nil, testErr) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) row, err := exactlyOneRowFromResult(ctx, r) @@ -1265,7 +1265,7 @@ func TestExactlyOneResultSetFromResult(t *testing.T) { }, }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) rs, err := exactlyOneResultSetFromResult(ctx, r) @@ -1375,7 +1375,7 @@ func TestExactlyOneResultSetFromResult(t *testing.T) { }, }, }, nil) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) rs, err := exactlyOneResultSetFromResult(ctx, r) @@ -1427,7 +1427,7 @@ func TestExactlyOneResultSetFromResult(t *testing.T) { }, nil) testErr := errors.New("test-err") stream.EXPECT().Recv().Return(nil, testErr) - r, _, err := newResult(ctx, stream, nil) + r, err := newResult(ctx, stream, nil) require.NoError(t, err) rs, err := exactlyOneResultSetFromResult(ctx, r) @@ -1544,7 +1544,7 @@ func TestCloseResultOnCloseClosableResultSet(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) var closed bool - r, _, err := newResult(ctx, stream, withTrace(&trace.Query{ + r, err := newResult(ctx, stream, withTrace(&trace.Query{ OnResultClose: func(info trace.QueryResultCloseStartInfo) func(info trace.QueryResultCloseDoneInfo) { require.False(t, closed) closed = true @@ -1915,7 +1915,7 @@ func TestResultStats(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) var s stats.QueryStats - result, _, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { + result, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { s = queryStats })) require.NoError(t, err) @@ -2276,7 +2276,7 @@ func TestResultStats(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) var s stats.QueryStats - result, _, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { + result, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { s = queryStats })) require.NoError(t, err) @@ -2638,7 +2638,7 @@ func TestResultStats(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) var s stats.QueryStats - result, _, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { + result, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { s = queryStats })) require.NoError(t, err) @@ -2975,7 +2975,7 @@ func TestResultStats(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) var s stats.QueryStats - result, _, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { + result, err := newResult(ctx, stream, withStatsCallback(func(queryStats stats.QueryStats) { s = queryStats })) require.NoError(t, err) @@ -3009,7 +3009,7 @@ func TestMaterializedResultStats(t *testing.T) { stream Ydb_Query_V1.QueryService_ExecuteQueryClient, opts ...resultOption, ) (query.Result, error) { - r, _, err := newResult(ctx, stream, opts...) + r, err := newResult(ctx, stream, opts...) if err != nil { return nil, err } diff --git a/internal/query/session.go b/internal/query/session.go index bba66c3b7..d094ef87d 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -35,7 +35,7 @@ func (s *Session) QueryResultSet( onDone(finalErr) }() - _, r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) + r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -51,7 +51,7 @@ func (s *Session) QueryResultSet( func (s *Session) queryRow( ctx context.Context, q string, settings executeSettings, resultOpts ...resultOption, ) (row query.Row, finalErr error) { - _, r, err := execute(ctx, s.ID(), s.client, q, settings, resultOpts...) + r, err := execute(ctx, s.ID(), s.client, q, settings, resultOpts...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -121,7 +121,7 @@ func (s *Session) Exec( onDone(finalErr) }() - _, r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) + r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) if err != nil { return xerrors.WithStackTrace(err) } @@ -143,7 +143,7 @@ func (s *Session) Query( onDone(finalErr) }() - _, r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) + r, err := execute(ctx, s.ID(), s.client, q, options.ExecuteSettings(opts...), withTrace(s.trace)) if err != nil { return nil, xerrors.WithStackTrace(err) } diff --git a/internal/query/session/session.go b/internal/query/session/session.go index 54612b5d7..69ca4aa58 100644 --- a/internal/query/session/session.go +++ b/internal/query/session/session.go @@ -2,7 +2,7 @@ package session import ( "context" - "fmt" + "sync/atomic" "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" @@ -35,7 +35,7 @@ type ( deleteTimeout time.Duration id string nodeID uint32 - status Status + status atomic.Uint32 closeOnce func(ctx context.Context) error checks []func(s *core) bool } @@ -50,25 +50,15 @@ func (c *core) NodeID() uint32 { } func (c *core) statusCode() Status { - return c.status + return Status(c.status.Load()) } func (c *core) SetStatus(status Status) { - switch c.status { - case statusUnknown: - c.status = status - case StatusIdle: - c.status = status - case StatusInUse: - c.status = status - case StatusClosing: - c.status = status - case StatusClosed: - c.status = status - case StatusError: - c.status = status + switch Status(c.status.Load()) { + case StatusClosed, StatusError: + // nop default: - panic(fmt.Sprintf("Unknown%d", c.status)) + c.status.Store(uint32(status)) } } @@ -111,10 +101,9 @@ func Open( core := &core{ Client: client, Trace: &trace.Query{}, - status: statusUnknown, checks: []func(s *core) bool{ func(s *core) bool { - return IsAlive(s.status) + return IsAlive(Status(s.status.Load())) }, }, } @@ -199,11 +188,10 @@ func (c *core) attach(ctx context.Context) (finalErr error) { _ = c.closeOnce(xcontext.ValueOnly(ctx)) }() - for func() bool { - _, recvErr := attach.Recv() - - return recvErr == nil - }() { + for c.IsAlive() { + if _, recvErr := attach.Recv(); recvErr != nil { + return + } } }() diff --git a/internal/query/transaction.go b/internal/query/transaction.go index 4b43a5288..02f6d7fd7 100644 --- a/internal/query/transaction.go +++ b/internal/query/transaction.go @@ -26,13 +26,9 @@ var ( _ baseTx.Transaction = (*Transaction)(nil) ) -const ( - LazyTxID = "LAZY_TX" -) - type ( Transaction struct { - baseTx.Identifier + baseTx.LazyID s *Session txSettings query.TransactionSettings @@ -46,34 +42,36 @@ type ( func begin( ctx context.Context, client Ydb_Query_V1.QueryServiceClient, - s *Session, + sessionID string, txSettings query.TransactionSettings, -) (baseTx.Identifier, error) { +) (txID string, _ error) { a := allocator.New() defer a.Free() response, err := client.BeginTransaction(ctx, &Ydb_Query.BeginTransactionRequest{ - SessionId: s.ID(), + SessionId: sessionID, TxSettings: txSettings.ToYDB(a), }, ) if err != nil { - return nil, xerrors.WithStackTrace(err) + return "", xerrors.WithStackTrace(err) } - return baseTx.NewID(response.GetTxMeta().GetId()), nil + return response.GetTxMeta().GetId(), nil } -func (tx *Transaction) UnLazy(ctx context.Context) (err error) { - if tx.Identifier != nil { +func (tx *Transaction) UnLazy(ctx context.Context) error { + if tx.ID() != baseTx.LazyTxID { return nil } - tx.Identifier, err = begin(ctx, tx.s.client, tx.s, tx.txSettings) + txID, err := begin(ctx, tx.s.client, tx.s.ID(), tx.txSettings) if err != nil { return xerrors.WithStackTrace(err) } + tx.SetTxID(txID) + return nil } @@ -97,6 +95,9 @@ func (tx *Transaction) QueryResultSet( resultOpts := []resultOption{ withTrace(tx.s.trace), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + tx.SetTxID(txMeta.GetId()) + }), } if settings.TxControl().Commit { // notification about complete transaction must be sended for any error or for successfully read all result if @@ -107,23 +108,11 @@ func (tx *Transaction) QueryResultSet( }), ) } - txID, r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) + r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) if err != nil { return nil, xerrors.WithStackTrace(err) } - if settings.TxControl().Commit { - if txID != nil && tx.Identifier != nil { - return nil, xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) - } - tx.completed = true - } else if tx.Identifier == nil { - if txID == nil { - return nil, xerrors.WithStackTrace(errExpectedTxID) - } - tx.Identifier = txID - } - rs, err = readResultSet(ctx, r) if err != nil { return nil, xerrors.WithStackTrace(err) @@ -150,6 +139,9 @@ func (tx *Transaction) QueryRow( resultOpts := []resultOption{ withTrace(tx.s.trace), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + tx.SetTxID(txMeta.GetId()) + }), } if settings.TxControl().Commit { // notification about complete transaction must be sended for any error or for successfully read all result if @@ -160,15 +152,11 @@ func (tx *Transaction) QueryRow( }), ) } - txID, r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) + r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) if err != nil { return nil, xerrors.WithStackTrace(err) } - if tx.Identifier == nil { - tx.Identifier = txID - } - row, err = readRow(ctx, r) if err != nil { return nil, xerrors.WithStackTrace(err) @@ -182,8 +170,8 @@ func (tx *Transaction) SessionID() string { } func (tx *Transaction) txControl() *queryTx.Control { - if tx.Identifier != nil { - return queryTx.NewControl(queryTx.WithTxID(tx.Identifier.ID())) + if tx.ID() != baseTx.LazyTxID { + return queryTx.NewControl(queryTx.WithTxID(tx.ID())) } return queryTx.NewControl( @@ -191,14 +179,6 @@ func (tx *Transaction) txControl() *queryTx.Control { ) } -func (tx *Transaction) ID() string { - if tx.Identifier == nil { - return LazyTxID - } - - return tx.Identifier.ID() -} - func (tx *Transaction) Exec(ctx context.Context, q string, opts ...options.Execute) ( finalErr error, ) { @@ -219,6 +199,9 @@ func (tx *Transaction) Exec(ctx context.Context, q string, opts ...options.Execu resultOpts := []resultOption{ withTrace(tx.s.trace), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + tx.SetTxID(txMeta.GetId()) + }), } if settings.TxControl().Commit { // notification about complete transaction must be sended for any error or for successfully read all result if @@ -230,23 +213,11 @@ func (tx *Transaction) Exec(ctx context.Context, q string, opts ...options.Execu ) } - txID, r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) + r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) if err != nil { return xerrors.WithStackTrace(err) } - if settings.TxControl().Commit { - if txID != nil && tx.Identifier != nil { - return xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) - } - tx.completed = true - } else if tx.Identifier == nil { - if txID == nil { - return xerrors.WithStackTrace(errExpectedTxID) - } - tx.Identifier = txID - } - err = readAll(ctx, r) if err != nil { return xerrors.WithStackTrace(err) @@ -255,10 +226,10 @@ func (tx *Transaction) Exec(ctx context.Context, q string, opts ...options.Execu return nil } -func (tx *Transaction) executeSettings(opts ...options.Execute) (_ executeSettings, _ error) { +func (tx *Transaction) executeSettings(opts ...options.Execute) (_ executeSettings, finalErr error) { for _, opt := range opts { if opt == nil { - return nil, xerrors.WithStackTrace(errExpectedTxID) + return nil, xerrors.WithStackTrace(errNilOption) } if _, has := opt.(options.ExecuteNoTx); has { return nil, xerrors.WithStackTrace( @@ -267,22 +238,8 @@ func (tx *Transaction) executeSettings(opts ...options.Execute) (_ executeSettin } } - if tx.Identifier != nil { - return options.ExecuteSettings(append([]options.Execute{ - options.WithTxControl( - queryTx.NewControl( - queryTx.WithTxID(tx.Identifier.ID()), - ), - ), - }, opts...)...), nil - } - return options.ExecuteSettings(append([]options.Execute{ - options.WithTxControl( - queryTx.NewControl( - queryTx.BeginTx(tx.txSettings...), - ), - ), + options.WithTxControl(tx.txControl()), }, opts...)...), nil } @@ -306,6 +263,9 @@ func (tx *Transaction) Query(ctx context.Context, q string, opts ...options.Exec resultOpts := []resultOption{ withTrace(tx.s.trace), + onTxMeta(func(txMeta *Ydb_Query.TransactionMeta) { + tx.SetTxID(txMeta.GetId()) + }), } if settings.TxControl().Commit { // notification about complete transaction must be sended for any error or for successfully read all result if @@ -316,25 +276,11 @@ func (tx *Transaction) Query(ctx context.Context, q string, opts ...options.Exec }), ) } - txID, r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) + r, err := execute(ctx, tx.s.ID(), tx.s.client, q, settings, resultOpts...) if err != nil { return nil, xerrors.WithStackTrace(err) } - if settings.TxControl().Commit { - if txID != nil && tx.Identifier != nil { - return nil, xerrors.WithStackTrace(errUnexpectedTxIDOnCommitFlag) - } - tx.completed = true - - return r, nil - } else if tx.Identifier == nil { - if txID == nil { - return nil, xerrors.WithStackTrace(errExpectedTxID) - } - tx.Identifier = txID - } - return r, nil } @@ -351,7 +297,7 @@ func commitTx(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi } func (tx *Transaction) CommitTx(ctx context.Context) (finalErr error) { - if tx.Identifier == nil { + if tx.ID() == baseTx.LazyTxID { return nil } @@ -389,8 +335,9 @@ func rollback(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi } func (tx *Transaction) Rollback(ctx context.Context) (finalErr error) { - if tx.Identifier == nil { - return nil + if tx.ID() == baseTx.LazyTxID { + // https://github.com/ydb-platform/ydb-go-sdk/issues/1456 + return tx.s.Close(ctx) } if tx.completed { @@ -418,6 +365,8 @@ func (tx *Transaction) OnCompleted(f baseTx.OnTransactionCompletedFunc) { } func (tx *Transaction) notifyOnCompleted(err error) { + tx.completed = true + tx.onCompleted.Range(func(f *baseTx.OnTransactionCompletedFunc) bool { (*f)(err) diff --git a/internal/query/transaction_fixtures_test.go b/internal/query/transaction_fixtures_test.go index ef3cbb8e2..78d6685d7 100644 --- a/internal/query/transaction_fixtures_test.go +++ b/internal/query/transaction_fixtures_test.go @@ -11,8 +11,12 @@ import ( func TransactionOverGrpcMock(e fixenv.Env) *Transaction { f := func() (*fixenv.GenericResult[*Transaction], error) { return fixenv.NewGenericResult(&Transaction{ - Identifier: tx.ID(fmt.Sprintf("test-transaction-id-%v", e.T().Name())), - s: SessionOverGrpcMock(e), + LazyID: func() (id tx.LazyID) { + id.SetTxID(fmt.Sprintf("test-transaction-id-%v", e.T().Name())) + + return id + }(), + s: SessionOverGrpcMock(e), }), nil } diff --git a/internal/query/transaction_test.go b/internal/query/transaction_test.go index 3cc8f3a23..30e2b2643 100644 --- a/internal/query/transaction_test.go +++ b/internal/query/transaction_test.go @@ -23,7 +23,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" internal "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/tx" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" + baseTx "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" "github.com/ydb-platform/ydb-go-sdk/v3/query" @@ -31,7 +31,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/table/stats" ) -var _ tx.Transaction = &Transaction{} +var _ baseTx.Transaction = &Transaction{} func TestBegin(t *testing.T) { t.Run("HappyWay", func(t *testing.T) { @@ -45,9 +45,9 @@ func TestBegin(t *testing.T) { }, }, nil) t.Log("begin") - tx, err := begin(ctx, client, &Session{Core: &sessionControllerMock{id: "123"}}, query.TxSettings()) + txID, err := begin(ctx, client, "123", query.TxSettings()) require.NoError(t, err) - require.Equal(t, "123", tx.ID()) + require.Equal(t, "123", txID) }) t.Run("TransportError", func(t *testing.T) { ctx := xtest.Context(t) @@ -55,7 +55,7 @@ func TestBegin(t *testing.T) { client := NewMockQueryServiceClient(ctrl) client.EXPECT().BeginTransaction(gomock.Any(), gomock.Any()).Return(nil, grpcStatus.Error(grpcCodes.Unavailable, "")) t.Log("begin") - _, err := begin(ctx, client, &Session{Core: &sessionControllerMock{id: "123"}}, query.TxSettings()) + _, err := begin(ctx, client, "123", query.TxSettings()) require.Error(t, err) require.True(t, xerrors.IsTransportError(err, grpcCodes.Unavailable)) }) @@ -67,7 +67,7 @@ func TestBegin(t *testing.T) { xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_UNAVAILABLE)), ) t.Log("begin") - _, err := begin(ctx, client, &Session{Core: &sessionControllerMock{id: "123"}}, query.TxSettings()) + _, err := begin(ctx, client, "123", query.TxSettings()) require.Error(t, err) require.True(t, xerrors.IsOperationError(err, Ydb.StatusIds_UNAVAILABLE)) }) diff --git a/internal/tx/id.go b/internal/tx/id.go index 83a8c856f..378d100ae 100644 --- a/internal/tx/id.go +++ b/internal/tx/id.go @@ -1,16 +1,38 @@ package tx -var _ Identifier = (*ID)(nil) +var ( + _ Identifier = (*ID)(nil) + _ Identifier = (*LazyID)(nil) +) + +const ( + LazyTxID = "LAZY_TX" +) type ( Identifier interface { ID() string isYdbTx() } - ID string + ID string + LazyID struct { + v *string + } ) -var Lazy = ID("") +func (id *LazyID) ID() string { + if id.v == nil { + return LazyTxID + } + + return *id.v +} + +func (id *LazyID) SetTxID(txID string) { + id.v = &txID +} + +func (id *LazyID) isYdbTx() {} func NewID(id string) ID { return ID(id) diff --git a/tests/integration/query_execute_test.go b/tests/integration/query_execute_test.go index 51308e005..dd7583240 100644 --- a/tests/integration/query_execute_test.go +++ b/tests/integration/query_execute_test.go @@ -6,11 +6,17 @@ package integration import ( "context" "encoding/json" + "errors" "fmt" + "io" + "math/rand" "os" + "path" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/ydb-platform/ydb-go-sdk/v3" @@ -18,6 +24,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" "github.com/ydb-platform/ydb-go-sdk/v3/log" "github.com/ydb-platform/ydb-go-sdk/v3/query" + "github.com/ydb-platform/ydb-go-sdk/v3/table/types" "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) @@ -287,3 +294,118 @@ func TestQueryExecute(t *testing.T) { require.NoError(t, err) }) } + +// https://github.com/ydb-platform/ydb-go-sdk/issues/1456 +func TestIssue1456TooManyUnknownTransactions(t *testing.T) { + if version.Lt(os.Getenv("YDB_VERSION"), "24.1") { + t.Skip("query service not allowed in YDB version '" + os.Getenv("YDB_VERSION") + "'") + } + + ctx, cancel := context.WithCancel(xtest.Context(t)) + defer cancel() + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + + const ( + tableSize = 10000 + queries = 1000 + chSize = 50 + ) + + tableName := path.Join(db.Name(), t.Name(), "test") + + err = db.Query().Exec(ctx, "DROP TABLE IF EXISTS `"+tableName+"`;") + require.NoError(t, err) + + err = db.Query().Exec(ctx, `CREATE TABLE `+"`"+tableName+"`"+` ( + id Utf8, + value Uint64, + PRIMARY KEY(id) + )`, + ) + require.NoError(t, err) + + var vals []types.Value + for i := 0; i < tableSize; i++ { + vals = append(vals, types.StructValue( + types.StructFieldValue("id", types.UTF8Value(uuid.NewString())), + types.StructFieldValue("value", types.Uint64Value(rand.Uint64())), + )) + } + err = db.Query().Do(context.Background(), func(ctx context.Context, s query.Session) error { + return s.Exec(ctx, ` + PRAGMA AnsiInForEmptyOrNullableItemsCollections; + DECLARE $vals AS List>; + + INSERT INTO `+"`"+tableName+"`"+` + SELECT id, value FROM AS_TABLE($vals);`, + query.WithParameters( + ydb.ParamsBuilder(). + Param("$vals").BeginList().AddItems(vals...).EndList().Build(), + ), + ) + }) + require.NoError(t, err) + + t.Run("Query", func(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(queries) + ch := make(chan struct{}, chSize) + for i := 0; i < queries; i++ { + ch <- struct{}{} + go func() { + defer func() { <-ch }() + defer wg.Done() + + err := db.Query().DoTx(ctx, func(ctx context.Context, tx query.TxActor) error { + var ( + id string + v uint64 + ) + + res, err := tx.Query(ctx, `SELECT id, value FROM `+"`"+tableName+"`") + if err != nil { + return err + } + + for { + set, err := res.NextResultSet(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + for { + row, err := set.NextRow(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return err + } + + err = row.Scan(&id, &v) + if err != nil { + return err + } + } + } + return res.Close(ctx) + }, query.WithTxSettings(query.TxSettings(query.WithSerializableReadWrite()))) + require.NoError(t, err) + }() + } + wg.Wait() + }) +} diff --git a/tests/integration/query_tx_execute_test.go b/tests/integration/query_tx_execute_test.go index 0552c4aec..fd3b110b0 100644 --- a/tests/integration/query_tx_execute_test.go +++ b/tests/integration/query_tx_execute_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" + baseTx "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" "github.com/ydb-platform/ydb-go-sdk/v3/internal/version" "github.com/ydb-platform/ydb-go-sdk/v3/query" ) @@ -30,14 +31,14 @@ func TestQueryTxExecute(t *testing.T) { ) t.Run("Default", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } res, err := tx.Query(ctx, "SELECT 1 AS col1") if err != nil { return err } - if tx.ID() == internalQuery.LazyTxID { + if tx.ID() == baseTx.LazyTxID { return errors.New("transaction is lazy yet") } rs, err := res.NextResultSet(ctx) @@ -71,14 +72,14 @@ func TestQueryTxExecute(t *testing.T) { }) t.Run("SerializableReadWrite", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } res, err := tx.Query(ctx, "SELECT 1 AS col1") if err != nil { return err } - if tx.ID() == internalQuery.LazyTxID { + if tx.ID() == baseTx.LazyTxID { return errors.New("transaction is lazy yet") } rs, err := res.NextResultSet(ctx) @@ -107,14 +108,14 @@ func TestQueryTxExecute(t *testing.T) { }) t.Run("SnapshotReadOnly", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } res, err := tx.Query(ctx, "SELECT 1 AS col1") if err != nil { return err } - if tx.ID() == internalQuery.LazyTxID { + if tx.ID() == baseTx.LazyTxID { return errors.New("transaction is lazy yet") } rs, err := res.NextResultSet(ctx) @@ -143,14 +144,14 @@ func TestQueryTxExecute(t *testing.T) { }) t.Run("OnlineReadOnly", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } res, err := tx.Query(ctx, "SELECT 1 AS col1") if err != nil { return err } - if tx.ID() == internalQuery.LazyTxID { + if tx.ID() == baseTx.LazyTxID { return errors.New("transaction is lazy yet") } rs, err := res.NextResultSet(ctx) @@ -177,14 +178,14 @@ func TestQueryTxExecute(t *testing.T) { }) t.Run("StaleReadOnly", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } res, err := tx.Query(ctx, "SELECT 1 AS col1") if err != nil { return err } - if tx.ID() == internalQuery.LazyTxID { + if tx.ID() == baseTx.LazyTxID { return errors.New("transaction is lazy yet") } rs, err := res.NextResultSet(ctx) @@ -211,7 +212,7 @@ func TestQueryTxExecute(t *testing.T) { }) t.Run("ErrOptionNotForTxExecute", func(t *testing.T) { err := scope.DriverWithLogs().Query().DoTx(scope.Ctx, func(ctx context.Context, tx query.TxActor) (err error) { - if tx.ID() != internalQuery.LazyTxID { + if tx.ID() != baseTx.LazyTxID { return errors.New("transaction is not lazy") } err = tx.Exec(ctx, "SELECT 1 AS col1",