From 674367f0c523a8c4432387f9a31727a193df7e6c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 10 Apr 2024 15:44:52 +0800 Subject: [PATCH] add SyncTask's timeout setting (#5209) Signed-off-by: Future-Outlier --- .../go/tasks/plugins/webapi/agent/plugin.go | 8 +++++--- .../go/tasks/plugins/webapi/agent/plugin_test.go | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index cc7f15bd80..03c04b4d27 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -92,12 +92,11 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion} agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry) - finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) - defer cancel() - taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) if isSync { + finalCtx, cancel := getFinalContext(ctx, "ExecuteTaskSync", agent) + defer cancel() client, err := p.getSyncAgentClient(ctx, agent) if err != nil { return nil, nil, err @@ -106,6 +105,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR return p.ExecuteTaskSync(finalCtx, client, header, inputs) } + finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) + defer cancel() + // Use async agent client client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 9fa36c5c42..3e8cb882c8 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -75,15 +75,26 @@ func TestPlugin(t *testing.T) { t.Run("test getFinalTimeout", func(t *testing.T) { timeout := getFinalTimeout("CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.Equal(t, 1*time.Millisecond, timeout.Duration) + timeout = getFinalTimeout("GetTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"GetTask": {Duration: 1 * time.Millisecond}}}) + assert.Equal(t, 1*time.Millisecond, timeout.Duration) timeout = getFinalTimeout("DeleteTask", &Deployment{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}}) assert.Equal(t, 10*time.Second, timeout.Duration) + timeout = getFinalTimeout("ExecuteTaskSync", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"ExecuteTaskSync": {Duration: 1 * time.Millisecond}}}) + assert.Equal(t, 1*time.Millisecond, timeout.Duration) }) t.Run("test getFinalContext", func(t *testing.T) { - ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &Deployment{}) + + ctx, _ := getFinalContext(context.TODO(), "CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + assert.NotEqual(t, context.TODO(), ctx) + + ctx, _ = getFinalContext(context.TODO(), "GetTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"GetTask": {Duration: 1 * time.Millisecond}}}) + assert.NotEqual(t, context.TODO(), ctx) + + ctx, _ = getFinalContext(context.TODO(), "DeleteTask", &Deployment{}) assert.Equal(t, context.TODO(), ctx) - ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) + ctx, _ = getFinalContext(context.TODO(), "ExecuteTaskSync", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"ExecuteTaskSync": {Duration: 10 * time.Second}}}) assert.NotEqual(t, context.TODO(), ctx) })