diff --git a/flyteplugins/go/tasks/plugins/testing/echo.go b/flyteplugins/go/tasks/plugins/testing/echo.go index 885ab5dfc4..7c55d3862f 100644 --- a/flyteplugins/go/tasks/plugins/testing/echo.go +++ b/flyteplugins/go/tasks/plugins/testing/echo.go @@ -3,6 +3,7 @@ package testing import ( "context" "fmt" + "sync" "time" idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -20,6 +21,7 @@ const ( type EchoPlugin struct { enqueueOwner core.EnqueueOwner taskStartTimes map[string]time.Time + sync.Mutex } func (e *EchoPlugin) GetID() string { @@ -30,9 +32,11 @@ func (e *EchoPlugin) GetProperties() core.PluginProperties { return core.PluginProperties{} } -func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { - echoConfig := ConfigSection.GetConfig().(*Config) - +// Enqueue the task to be re-evaluated after SleepDuration. +// If the task is already enqueued, return the start time of the task. +func (e *EchoPlugin) addTask(ctx context.Context, tCtx core.TaskExecutionContext) time.Time { + e.Lock() + defer e.Unlock() var startTime time.Time var exists bool taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() @@ -42,47 +46,34 @@ func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) // start timer to enqueue owner once task sleep duration has elapsed go func() { + echoConfig := ConfigSection.GetConfig().(*Config) time.Sleep(echoConfig.SleepDuration.Duration) if err := e.enqueueOwner(tCtx.TaskExecutionMetadata().GetOwnerID()); err != nil { logger.Warnf(ctx, "failed to enqueue owner [%s]: %v", tCtx.TaskExecutionMetadata().GetOwnerID(), err) } }() } + return startTime +} - if time.Since(startTime) >= echoConfig.SleepDuration.Duration { - // copy inputs to outputs - inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx) - if err != nil { - return core.UnknownTransition, err - } - - if len(inputToOutputVariableMappings) > 0 { - inputLiterals, err := tCtx.InputReader().Get(ctx) - if err != nil { - return core.UnknownTransition, err - } - - outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings)) - for inputVariableName, outputVariableName := range inputToOutputVariableMappings { - outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName] - } +// Remove the task from the taskStartTimes map. +func (e *EchoPlugin) removeTask(taskExecutionID string) { + e.Lock() + defer e.Unlock() + delete(e.taskStartTimes, taskExecutionID) +} - outputLiteralMap := &idlcore.LiteralMap{ - Literals: outputLiterals, - } +func (e *EchoPlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + echoConfig := ConfigSection.GetConfig().(*Config) - outputFile := tCtx.OutputWriter().GetOutputPath() - if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil { - return core.UnknownTransition, err - } + if echoConfig.SleepDuration.Duration == time.Duration(0) { + return copyInputsToOutputs(ctx, tCtx) + } - or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0) - if err = tCtx.OutputWriter().Put(ctx, or); err != nil { - return core.UnknownTransition, err - } - } + startTime := e.addTask(ctx, tCtx) - return core.DoTransition(core.PhaseInfoSuccess(nil)), nil + if time.Since(startTime) >= echoConfig.SleepDuration.Duration { + return copyInputsToOutputs(ctx, tCtx) } return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil @@ -94,10 +85,45 @@ func (e *EchoPlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) func (e *EchoPlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - delete(e.taskStartTimes, taskExecutionID) + e.removeTask(taskExecutionID) return nil } +// copyInputsToOutputs copies the input literals to the output location. +func copyInputsToOutputs(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + inputToOutputVariableMappings, err := compileInputToOutputVariableMappings(ctx, tCtx) + if err != nil { + return core.UnknownTransition, err + } + + if len(inputToOutputVariableMappings) > 0 { + inputLiterals, err := tCtx.InputReader().Get(ctx) + if err != nil { + return core.UnknownTransition, err + } + + outputLiterals := make(map[string]*idlcore.Literal, len(inputToOutputVariableMappings)) + for inputVariableName, outputVariableName := range inputToOutputVariableMappings { + outputLiterals[outputVariableName] = inputLiterals.Literals[inputVariableName] + } + + outputLiteralMap := &idlcore.LiteralMap{ + Literals: outputLiterals, + } + + outputFile := tCtx.OutputWriter().GetOutputPath() + if err := tCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputLiteralMap); err != nil { + return core.UnknownTransition, err + } + + or := ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0) + if err = tCtx.OutputWriter().Put(ctx, or); err != nil { + return core.UnknownTransition, err + } + } + return core.DoTransition(core.PhaseInfoSuccess(nil)), nil +} + func compileInputToOutputVariableMappings(ctx context.Context, tCtx core.TaskExecutionContext) (map[string]string, error) { // validate outputs are castable from inputs otherwise error as this plugin is not applicable taskTemplate, err := tCtx.TaskReader().Read(ctx)