Skip to content

Commit

Permalink
add SyncTask's timeout setting (#5209)
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored Apr 10, 2024
1 parent 6a39af7 commit 674367f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
8 changes: 5 additions & 3 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,11 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion}
agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry)

finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
defer cancel()

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

if isSync {
finalCtx, cancel := getFinalContext(ctx, "ExecuteTaskSync", agent)
defer cancel()
client, err := p.getSyncAgentClient(ctx, agent)
if err != nil {
return nil, nil, err
Expand All @@ -106,6 +105,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
return p.ExecuteTaskSync(finalCtx, client, header, inputs)
}

finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
defer cancel()

// Use async agent client
client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down
15 changes: 13 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,26 @@ func TestPlugin(t *testing.T) {
t.Run("test getFinalTimeout", func(t *testing.T) {
timeout := getFinalTimeout("CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.Equal(t, 1*time.Millisecond, timeout.Duration)
timeout = getFinalTimeout("GetTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"GetTask": {Duration: 1 * time.Millisecond}}})
assert.Equal(t, 1*time.Millisecond, timeout.Duration)
timeout = getFinalTimeout("DeleteTask", &Deployment{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}})
assert.Equal(t, 10*time.Second, timeout.Duration)
timeout = getFinalTimeout("ExecuteTaskSync", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"ExecuteTaskSync": {Duration: 1 * time.Millisecond}}})
assert.Equal(t, 1*time.Millisecond, timeout.Duration)
})

t.Run("test getFinalContext", func(t *testing.T) {
ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &Deployment{})

ctx, _ := getFinalContext(context.TODO(), "CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.NotEqual(t, context.TODO(), ctx)

ctx, _ = getFinalContext(context.TODO(), "GetTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"GetTask": {Duration: 1 * time.Millisecond}}})
assert.NotEqual(t, context.TODO(), ctx)

ctx, _ = getFinalContext(context.TODO(), "DeleteTask", &Deployment{})
assert.Equal(t, context.TODO(), ctx)

ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
ctx, _ = getFinalContext(context.TODO(), "ExecuteTaskSync", &Deployment{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"ExecuteTaskSync": {Duration: 10 * time.Second}}})
assert.NotEqual(t, context.TODO(), ctx)
})

Expand Down

0 comments on commit 674367f

Please sign in to comment.