Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for raw containers in the map task #329

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
6 changes: 4 additions & 2 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const SIGKILL = 137
const defaultContainerTemplateName = "default"
const primaryContainerTemplateName = "primary"
const PrimaryContainerKey = "primary_container_name"
const FlyteCopilotName = "flyte_copilot_name"
const Sidecar = "sidecar"

// ApplyInterruptibleNodeSelectorRequirement configures the node selector requirement of the node-affinity using the configuration specified.
func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1.Affinity) {
Expand Down Expand Up @@ -550,10 +552,10 @@ func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo
return pluginsCore.PhaseInfoSuccess(&info), nil
}

// DeterminePrimaryContainerPhase as the name suggests, given all the containers, will return a pluginsCore.PhaseInfo object
// DetermineContainerPhase as the name suggests, given all the containers, will return a pluginsCore.PhaseInfo object
// corresponding to the phase of the primaryContainer which is identified using the provided name.
// This is useful in case of sidecars or pod jobs, where Flyte will monitor successful exit of a single container.
func DeterminePrimaryContainerPhase(primaryContainerName string, statuses []v1.ContainerStatus, info *pluginsCore.TaskInfo) pluginsCore.PhaseInfo {
func DetermineContainerPhase(primaryContainerName string, statuses []v1.ContainerStatus, info *pluginsCore.TaskInfo) pluginsCore.PhaseInfo {
for _, s := range statuses {
if s.Name == primaryContainerName {
if s.State.Waiting != nil || s.State.Running != nil {
Expand Down
10 changes: 5 additions & 5 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
}
var info = &pluginsCore.TaskInfo{}
t.Run("primary container waiting", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
phaseInfo := DetermineContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Expand All @@ -962,7 +962,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase())
})
t.Run("primary container running", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
phaseInfo := DetermineContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Expand All @@ -975,7 +975,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase())
})
t.Run("primary container failed", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
phaseInfo := DetermineContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Expand All @@ -992,7 +992,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
assert.Equal(t, "foo failed", phaseInfo.Err().Message)
})
t.Run("primary container succeeded", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
phaseInfo := DetermineContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Expand All @@ -1005,7 +1005,7 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase())
})
t.Run("missing primary container", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
phaseInfo := DetermineContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer,
}, info)
assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase())
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
phaseInfo = core.PhaseInfoWaitingForResourcesInfo(time.Now(), core.DefaultPhaseVersion, "Exceeded ResourceManager quota", nil)
} else {
phaseInfo, perr = launchSubtask(ctx, stCtx, config, kubeClient)

logger.Infof(ctx, "Failed to launch subtask with error [%s]", perr)
// if launchSubtask fails we attempt to deallocate the (previously allocated)
// resource to mitigate leaks
if perr != nil {
Expand Down
14 changes: 13 additions & 1 deletion go/tasks/plugins/array/k8s/subtask.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
"strings"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"
Expand Down Expand Up @@ -186,6 +187,17 @@ func launchSubtask(ctx context.Context, stCtx SubTaskExecutionContext, cfg *Conf
Value: strconv.Itoa(stCtx.originalIndex),
})

for sidecarIndex, container := range pod.Spec.Containers {
if container.Name == config.GetK8sPluginConfig().CoPilot.NamePrefix+flytek8s.Sidecar {
for i, arg := range pod.Spec.Containers[sidecarIndex].Args {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @hamersaw should we pass the FlyteK8sArrayIndex env var to the copilot, and construct the final output prefix in the copilot?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong feelings here. How are we passing the array index to the inputs downloader? Because in flytekit we pass the input data ref and a subtask index, IIUC it reads the full list of inputs and only uses the value at the subtask index. We need to do the same thing here right?

Copy link
Member Author

@pingsutw pingsutw Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We pass the array index to the primary container instead of the downloader, and the raw container task reads the value at the subtask index. Here is an example: flyteorg/flytekit#1547.

The problem is that in the regular map task, we construct the final output prefix in flytekit (output_prefix + array index), but the raw container doesn't know the output prefix; it writes to a local shared directory instead. The uploader then reads the data in the shared directory and uploads it to S3.

if arg == "--to-output-prefix" {
pod.Spec.Containers[sidecarIndex].Args[i+1] = fmt.Sprintf("%s/%s", pod.Spec.Containers[sidecarIndex].Args[i+1], strconv.Itoa(stCtx.originalIndex))
}
}
break
}
}

pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, arrayJobEnvVars...)

logger.Infof(ctx, "Creating Object: Type:[%v], Object:[%v/%v]", pod.GetObjectKind().GroupVersionKind(), pod.GetNamespace(), pod.GetName())
Expand Down
18 changes: 17 additions & 1 deletion go/tasks/plugins/k8s/pod/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package pod
import (
"context"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"

pluginserrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
Expand Down Expand Up @@ -132,6 +134,11 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu
pod.ObjectMeta = *objectMeta
pod.Spec = *podSpec

if taskTemplate.GetContainer() != nil && taskTemplate.GetContainer().DataConfig != nil && taskTemplate.GetContainer().DataConfig.Enabled {
pod.Annotations[flytek8s.PrimaryContainerKey] = primaryContainerName
pod.Annotations[flytek8s.FlyteCopilotName] = config.GetK8sPluginConfig().CoPilot.NamePrefix + flytek8s.Sidecar
}

return pod, nil
}

Expand Down Expand Up @@ -184,8 +191,17 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info), nil
}

// When the copilot is running, we should wait until the data is uploaded by the copilot.
copilotContainerName, exists := r.GetAnnotations()[flytek8s.FlyteCopilotName]
if exists {
copilotContainerPhase := flytek8s.DetermineContainerPhase(copilotContainerName, pod.Status.ContainerStatuses, &info)
if copilotContainerPhase.Phase() == pluginsCore.PhaseRunning && len(info.Logs) > 0 {
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, copilotContainerPhase.Info()), nil
}
}

Copy link
Contributor

@hamersaw hamersaw Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So right now this is done in ContainerTasks (not using map task) by just not setting a primaryContainerName on the Pod; this code then waits for the entire Pod to complete. It seems like this is what we should do for subtasks as well.

I think the issue is that here we always add a PrimaryContainerName to the pod annotation. Maybe it makes sense to update the code in subtask.go so that this annotation is only added if necessary; then we shouldn't need to add the flytek8s.FlyteCopilotName annotation above either.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that propeller only passes the array index to the primary container, so we have to set the raw container as primary. However, if we set it as primary, propeller won't wait for the uploader to complete, so I added flytek8s.FlyteCopilotName to the annotations, and we wait for the copilot first if we find the uploader container in the pod here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so we should probably update this logic then to support ContainerTask (ie. add array index to non-copilot containers) so that it doesn't rely on the primaryContainerName annotation.

This means we could keep the logic in PodPlugin so that if the primaryContainerName annotation exists, it waits for that container to complete. If it doesn't, then it waits for the Pod to complete. It helps if this logic is simple because we have a few perf ideas to layer on top of it.

// if the primary container annotation exists, we use the status of the specified container
primaryContainerPhase := flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info)
primaryContainerPhase := flytek8s.DetermineContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info)
if primaryContainerPhase.Phase() == pluginsCore.PhaseRunning && len(info.Logs) > 0 {
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion+1, primaryContainerPhase.Info()), nil
}
Expand Down