diff --git a/context.go b/context.go index 61786d29..70845c22 100644 --- a/context.go +++ b/context.go @@ -22,6 +22,9 @@ func withClient[TTx any](ctx context.Context, client *Client[TTx]) context.Conte // // When testing JobArgs.Work implementations, it might be useful to use // rivertest.WorkContext to initialize a context that has an available client. +// +// The type parameter TTx is the transaction type used by the [Client], +// pgx.Tx for the pgx driver, and *sql.Tx for the [database/sql] driver. func ClientFromContext[TTx any](ctx context.Context) *Client[TTx] { client, err := ClientFromContextSafely[TTx](ctx) if err != nil { @@ -39,6 +42,9 @@ func ClientFromContext[TTx any](ctx context.Context) *Client[TTx] { // // When testing JobArgs.Work implementations, it might be useful to use // rivertest.WorkContext to initialize a context that has an available client. +// +// See the examples for [ClientFromContext] to understand how to use this +// function. func ClientFromContextSafely[TTx any](ctx context.Context) (*Client[TTx], error) { client, exists := ctx.Value(rivercommon.ContextKeyClient{}).(*Client[TTx]) if !exists || client == nil { diff --git a/example_client_from_context_dbsql_test.go b/example_client_from_context_dbsql_test.go new file mode 100644 index 00000000..f306e914 --- /dev/null +++ b/example_client_from_context_dbsql_test.go @@ -0,0 +1,88 @@ +package river_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log/slog" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverdatabasesql" + "github.com/riverqueue/river/rivershared/util/slogutil" +) + +type ContextClientSQLArgs struct{} + +func (args ContextClientSQLArgs) Kind() string { return "ContextClientSQLWorker" } + +type ContextClientSQLWorker struct { + river.WorkerDefaults[ContextClientSQLArgs] +} + +func (w *ContextClientSQLWorker) Work(ctx context.Context, job *river.Job[ContextClientSQLArgs]) error { + client := river.ClientFromContext[*sql.Tx](ctx) + if client == nil { + fmt.Println("client not found in context") + return errors.New("client not found in context") + } + + fmt.Printf("client found in context, id=%s\n", client.ID()) + return nil +} + +// ExampleClientFromContext_databaseSQL demonstrates how to extract the River +// client from the worker context when using the [database/sql] driver. +// ([github.com/riverqueue/river/riverdriver/riverdatabasesql]) +func ExampleClientFromContext_databaseSQL() { + ctx := context.Background() + + config := riverinternaltest.DatabaseConfig("river_test_example") + db, err := sql.Open("pgx", config.ConnString()) + if err != nil { + panic(err) + } + defer db.Close() + + workers := river.NewWorkers() + river.AddWorker(workers, &ContextClientSQLWorker{}) + + riverClient, err := river.NewClient(riverdatabasesql.New(db), &river.Config{ + ID: "ClientFromContextClientSQL", + Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 10}, + }, + FetchCooldown: 10 * time.Millisecond, + FetchPollInterval: 10 * time.Millisecond, + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Not strictly needed, but used to help this test wait until job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + if _, err := riverClient.Insert(ctx, ContextClientSQLArgs{}, nil); err != nil { + panic(err) + } + + waitForNJobs(subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // client found in context, id=ClientFromContextClientSQL +} diff --git a/example_client_from_context_test.go b/example_client_from_context_test.go new file mode 100644 index 00000000..d40f690d --- /dev/null +++ b/example_client_from_context_test.go @@ -0,0 +1,89 @@ +package river_test + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/util/slogutil" +) + +type ContextClientArgs struct{} + +func (args ContextClientArgs) Kind() string { return "ContextClientWorker" } + +type ContextClientWorker struct { + river.WorkerDefaults[ContextClientArgs] +} + +func (w *ContextClientWorker) Work(ctx context.Context, job *river.Job[ContextClientArgs]) error { + client := river.ClientFromContext[pgx.Tx](ctx) + if client == nil { + fmt.Println("client not found in context") + return errors.New("client not found in context") + } + + fmt.Printf("client found in context, id=%s\n", client.ID()) + return nil +} + +// ExampleClientFromContext_pgx demonstrates how to extract the River client +// from the worker context when using the pgx/v5 driver. +// ([github.com/riverqueue/river/riverdriver/riverpgxv5]) +func ExampleClientFromContext_pgx() { + ctx := context.Background() + + dbPool, err := pgxpool.NewWithConfig(ctx, riverinternaltest.DatabaseConfig("river_test_example")) + if err != nil { + panic(err) + } + defer dbPool.Close() + + // Required for the purpose of this test, but not necessary in real usage. + if err := riverinternaltest.TruncateRiverTables(ctx, dbPool); err != nil { + panic(err) + } + + workers := river.NewWorkers() + river.AddWorker(workers, &ContextClientWorker{}) + + riverClient, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{ + ID: "ClientFromContextClient", + Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), + Queues: map[string]river.QueueConfig{ + river.QueueDefault: {MaxWorkers: 10}, + }, + TestOnly: true, // suitable only for use in tests; remove for live environments + Workers: workers, + }) + if err != nil { + panic(err) + } + + // Not strictly needed, but used to help this test wait until job is worked. + subscribeChan, subscribeCancel := riverClient.Subscribe(river.EventKindJobCompleted) + defer subscribeCancel() + + if err := riverClient.Start(ctx); err != nil { + panic(err) + } + if _, err = riverClient.Insert(ctx, ContextClientArgs{}, nil); err != nil { + panic(err) + } + + waitForNJobs(subscribeChan, 1) + + if err := riverClient.Stop(ctx); err != nil { + panic(err) + } + + // Output: + // client found in context, id=ClientFromContextClient +} diff --git a/go.work.sum b/go.work.sum index 4aaf074b..545b9d94 100644 --- a/go.work.sum +++ b/go.work.sum @@ -4,7 +4,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/riverqueue/river v0.13.0-rc.1/go.mod h1:ZxTeoNZWNpwl+dCBWF5AomGV1exZbHu/E75ufb09HIo= -github.com/riverqueue/river/riverdriver/riverdatabasesql v0.13.0/go.mod h1:f7TWWD965tE6v96qi1Y40IP2shsAai0qJBHbqT7yFLM= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -19,6 +18,7 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= @@ -29,6 +29,7 @@ golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2 h1:IRJeR9r1pYWsHKTRe/IInb7lYvbBVIqOgsX/u0mbOWY= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= @@ -36,6 +37,7 @@ golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=