Skip to content

Commit

Permalink
Transfer commits
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Oct 3, 2023
2 parents d9586b0 + 536f80c commit ed2f143
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 20 deletions.
6 changes: 6 additions & 0 deletions flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ require (
sigs.k8s.io/yaml v1.3.0 // indirect
)

<<<<<<< HEAD
replace (
github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d
github.com/flyteorg/flyte/datacatalog => ../datacatalog
Expand All @@ -144,3 +145,8 @@ replace (
github.com/flyteorg/flyte/flytestdlib => ../flytestdlib
github.com/flyteorg/flyteidl => ../flyteidl
)
=======
replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d

replace github.com/flyteorg/flyteidl => /Users/andrew/dev/forks/flyteidl
>>>>>>> flyteplugins/spark-pod-templates
5 changes: 5 additions & 0 deletions flyteplugins/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
<<<<<<< HEAD
=======
github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk=
github.com/flyteorg/flytestdlib v1.0.24/go.mod h1:6nXa5g00qFIsgdvQ7jKQMJmDniqO0hG6Z5X5olfduqQ=
>>>>>>> flyteplugins/spark-pod-templates
github.com/flyteorg/stow v0.3.7 h1:Cx7j8/Ux6+toD5hp5fy++927V+yAcAttDeQAlUD/864=
github.com/flyteorg/stow v0.3.7/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ type TaskExecutionMetadata interface {
GetSecurityContext() core.SecurityContext
IsInterruptible() bool
GetPlatformResources() *v1.ResourceRequirements
GetInterruptibleFailureThreshold() uint32
GetInterruptibleFailureThreshold() int32
GetEnvironmentVariables() map[string]string
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecution
}

subTaskExecutionID := NewSubTaskExecutionID(taskExecutionMetadata.GetTaskExecutionID(), executionIndex, retryAttempt)
interruptible := taskExecutionMetadata.IsInterruptible() && uint32(systemFailures) < taskExecutionMetadata.GetInterruptibleFailureThreshold()
interruptible := taskExecutionMetadata.IsInterruptible() && int32(systemFailures) < taskExecutionMetadata.GetInterruptibleFailureThreshold()
return SubTaskExecutionMetadata{
taskExecutionMetadata,
utils.UnionMaps(taskExecutionMetadata.GetAnnotations(), secretsMap),
Expand Down
40 changes: 37 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"google.golang.org/protobuf/types/known/structpb"
"sigs.k8s.io/controller-runtime/pkg/client"

"strconv"
Expand All @@ -20,14 +21,22 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"

<<<<<<< HEAD
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
=======
v1 "k8s.io/api/core/v1"
>>>>>>> flyteplugins/spark-pod-templates
"k8s.io/client-go/kubernetes/scheme"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"regexp"
"strings"
Expand Down Expand Up @@ -61,6 +70,21 @@ func (sparkResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

func getTolerations(podSpecPb *structpb.Struct) ([]v1.Toleration, error) {
tolerations := make([]v1.Toleration, 0)
tolerations = append(tolerations, config.GetK8sPluginConfig().DefaultTolerations...)
if podSpecPb != nil {
var podSpec v1.PodSpec
err := utils.UnmarshalStruct(podSpecPb, &podSpec)
if err != nil {
return nil, errors.Wrapf(errors.BadTaskSpecification, err,
"invalid pod spec [%v], failed to unmarshal", podSpec)
}
tolerations = append(tolerations, podSpec.Tolerations...)
}
return tolerations, nil
}

// Creates a new Job that will execute the main container as well as any generated types the result from the execution.
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -99,6 +123,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}

tolerations, err := getTolerations(sparkJob.GetDriverPod().GetPodSpec())
if err != nil {
return nil, err
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Expand All @@ -108,14 +137,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

tolerations, err = getTolerations(sparkJob.GetExecutorPod().GetPodSpec())
if err != nil {
return nil, err
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(),
Expand All @@ -125,7 +159,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
Expand Down
92 changes: 82 additions & 10 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ import (
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
)

const sparkMainClass = "MainClass"
Expand Down Expand Up @@ -87,7 +88,8 @@ func TestGetEventInfo(t *testing.T) {
},
},
}))
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", sparkJob), false)
info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState))
assert.NoError(t, err)
assert.Len(t, info.Logs, 6)
Expand Down Expand Up @@ -157,7 +159,8 @@ func TestGetTaskPhase(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

ctx := context.TODO()
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", sparkJob), false)
taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState))
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued)
Expand Down Expand Up @@ -242,17 +245,14 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication {

func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
sparkJob := plugins.SparkJob{}

sparkJob.MainClass = sparkMainClass
sparkJob.MainApplicationFile = sparkApplicationFile
sparkJob.SparkConf = sparkConf
sparkJob.ApplicationType = plugins.SparkApplication_PYTHON
return &sparkJob
}

func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate {

sparkJob := dummySparkCustomObj(sparkConf)
func dummySparkTaskTemplate(id string, sparkJob *plugins.SparkJob) *core.TaskTemplate {
sparkJobJSON, err := utils.MarshalToString(sparkJob)
if err != nil {
panic(err)
Expand Down Expand Up @@ -335,7 +335,8 @@ func TestBuildResourceSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Expand Down Expand Up @@ -619,7 +620,8 @@ func TestBuildResourceSpark(t *testing.T) {
dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3"
dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4"

taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest)
sparkJob = dummySparkCustomObj(dummyConfWithRequest)
taskTemplate = dummySparkTaskTemplate("blah-1", sparkJob)
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)
assert.NotNil(t, resource)
Expand Down Expand Up @@ -678,6 +680,76 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Nil(t, resource)
}

func TestBuildResourcePodTemplate(t *testing.T) {
defaultToleration := corev1.Toleration{

Key: "x/flyte",
Value: "default",
Operator: "Equal",
}
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultTolerations: []corev1.Toleration{
defaultToleration,
},
})
assert.NoError(t, err)
sparkJob := dummySparkCustomObj(dummySparkConf)
extraDriverToleration := corev1.Toleration{
Key: "x/flyte",
Value: "extra-driver",
Operator: "Equal",
}
podSpec := corev1.PodSpec{
Tolerations: []corev1.Toleration{
extraDriverToleration,
},
}
driverPodSpecPb := structpb.Struct{}
err = utils.MarshalStruct(&podSpec, &driverPodSpecPb)
assert.NoError(t, err)
sparkJob.DriverPodValue = &plugins.SparkJob_DriverPod{
DriverPod: &core.K8SPod{
PodSpec: &driverPodSpecPb,
},
}
extraExecutorToleration := corev1.Toleration{
Key: "x/flyte",
Value: "extra-executor",
Operator: "Equal",
}
podSpec = corev1.PodSpec{
Tolerations: []corev1.Toleration{
extraExecutorToleration,
},
}
execPodSpecPb := structpb.Struct{}
err = utils.MarshalStruct(&podSpec, &execPodSpecPb)
assert.NoError(t, err)
sparkJob.ExecutorPodValue = &plugins.SparkJob_ExecutorPod{
ExecutorPod: &core.K8SPod{
PodSpec: &execPodSpecPb,
},
}
taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)
sparkResourceHandler := sparkResourceHandler{}
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)

assert.NotNil(t, resource)
sparkApp, ok := resource.(*sj.SparkApplication)
assert.True(t, ok)
assert.Equal(t, 2, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, sparkApp.Spec.Driver.Tolerations, []corev1.Toleration{
defaultToleration,
extraDriverToleration,
})
assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, sparkApp.Spec.Executor.Tolerations, []corev1.Toleration{
defaultToleration,
extraExecutorToleration,
})
}

func TestGetPropertiesSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}
expected := k8s.PluginProperties{}
Expand Down

0 comments on commit ed2f143

Please sign in to comment.