Skip to content

Commit

Permalink
fix(backend): fixes "cannot save parameter" error message. Fixes kube…
Browse files Browse the repository at this point in the history
…flow#9678 (kubeflow#10459)

Signed-off-by: hbelmiro <[email protected]>
(cherry picked from commit 1ae0a82)
  • Loading branch information
hbelmiro committed Feb 22, 2024
1 parent 57830bf commit 90a92f9
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 10 deletions.
9 changes: 9 additions & 0 deletions backend/src/v2/cmd/driver/execution_paths.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package main

type ExecutionPaths struct {
ExecutionID string
IterationCount string
CachedDecision string
Condition string
PodSpecPatch string
}
48 changes: 38 additions & 10 deletions backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ import (

const (
driverTypeArg = "type"
ROOT_DAG = "ROOT_DAG"
DAG = "DAG"
CONTAINER = "CONTAINER"
)

var (
Expand Down Expand Up @@ -160,12 +163,12 @@ func drive() (err error) {
var execution *driver.Execution
var driverErr error
switch *driverType {
case "ROOT_DAG":
case ROOT_DAG:
options.RuntimeConfig = runtimeConfig
execution, driverErr = driver.RootDAG(ctx, options, client)
case "DAG":
case DAG:
execution, driverErr = driver.DAG(ctx, options, client)
case "CONTAINER":
case CONTAINER:
options.Container = containerSpec
options.KubernetesExecutorConfig = k8sExecCfg
execution, driverErr = driver.Container(ctx, options, client, cacheClient)
Expand All @@ -183,35 +186,60 @@ func drive() (err error) {
err = driverErr
}()
}

executionPaths := &ExecutionPaths{
ExecutionID: *executionIDPath,
IterationCount: *iterationCountPath,
CachedDecision: *cachedDecisionPath,
Condition: *conditionPath,
PodSpecPatch: *podSpecPatchPath}

return handleExecution(execution, *driverType, executionPaths)
}

func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error {
if execution.ID != 0 {
glog.Infof("output execution.ID=%v", execution.ID)
if *executionIDPath != "" {
if err = writeFile(*executionIDPath, []byte(fmt.Sprint(execution.ID))); err != nil {
if executionPaths.ExecutionID != "" {
if err := writeFile(executionPaths.ExecutionID, []byte(fmt.Sprint(execution.ID))); err != nil {
return fmt.Errorf("failed to write execution ID to file: %w", err)
}
}
}
if execution.IterationCount != nil {
if err = writeFile(*iterationCountPath, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil {
if err := writeFile(executionPaths.IterationCount, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil {
return fmt.Errorf("failed to write iteration count to file: %w", err)
}
} else {
if driverType == ROOT_DAG {
if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil {
return fmt.Errorf("failed to write iteration count to file: %w", err)
}
}
}
if execution.Cached != nil {
if err = writeFile(*cachedDecisionPath, []byte(strconv.FormatBool(*execution.Cached))); err != nil {
if err := writeFile(executionPaths.CachedDecision, []byte(strconv.FormatBool(*execution.Cached))); err != nil {
return fmt.Errorf("failed to write cached decision to file: %w", err)
}
}
if execution.Condition != nil {
if err = writeFile(*conditionPath, []byte(strconv.FormatBool(*execution.Condition))); err != nil {
if err := writeFile(executionPaths.Condition, []byte(strconv.FormatBool(*execution.Condition))); err != nil {
return fmt.Errorf("failed to write condition to file: %w", err)
}
} else {
// nil is a valid value for Condition
if driverType == ROOT_DAG || driverType == CONTAINER {
if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil {
return fmt.Errorf("failed to write condition to file: %w", err)
}
}
}
if execution.PodSpecPatch != "" {
glog.Infof("output podSpecPatch=\n%s\n", execution.PodSpecPatch)
if *podSpecPatchPath == "" {
if executionPaths.PodSpecPatch == "" {
return fmt.Errorf("--pod_spec_patch_path is required for container executor drivers")
}
if err = writeFile(*podSpecPatchPath, []byte(execution.PodSpecPatch)); err != nil {
if err := writeFile(executionPaths.PodSpecPatch, []byte(execution.PodSpecPatch)); err != nil {
return fmt.Errorf("failed to write pod spec patch to file: %w", err)
}
}
Expand Down
79 changes: 79 additions & 0 deletions backend/src/v2/cmd/driver/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package main

import (
"github.com/kubeflow/pipelines/backend/src/v2/driver"
"os"
"testing"
)

func Test_handleExecutionContainer(t *testing.T) {
execution := &driver.Execution{}

executionPaths := &ExecutionPaths{
Condition: "condition.txt",
}

err := handleExecution(execution, CONTAINER, executionPaths)

if err != nil {
t.Errorf("Unexpected error: %v", err)
}

verifyFileContent(t, executionPaths.Condition, "nil")

cleanup(t, executionPaths)
}

func Test_handleExecutionRootDAG(t *testing.T) {
execution := &driver.Execution{}

executionPaths := &ExecutionPaths{
IterationCount: "iteration_count.txt",
Condition: "condition.txt",
}

err := handleExecution(execution, ROOT_DAG, executionPaths)

if err != nil {
t.Errorf("Unexpected error: %v", err)
}

verifyFileContent(t, executionPaths.IterationCount, "0")
verifyFileContent(t, executionPaths.Condition, "nil")

cleanup(t, executionPaths)
}

func cleanup(t *testing.T, executionPaths *ExecutionPaths) {
removeIfExists(t, executionPaths.IterationCount)
removeIfExists(t, executionPaths.ExecutionID)
removeIfExists(t, executionPaths.Condition)
removeIfExists(t, executionPaths.PodSpecPatch)
removeIfExists(t, executionPaths.CachedDecision)
}

func removeIfExists(t *testing.T, filePath string) {
_, err := os.Stat(filePath)
if err == nil {
err = os.Remove(filePath)
if err != nil {
t.Errorf("Unexpected error while removing the created file: %v", err)
}
}
}

func verifyFileContent(t *testing.T, filePath string, expectedContent string) {
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
t.Errorf("Expected file %s to be created, but it doesn't exist", filePath)
}

fileContent, err := os.ReadFile(filePath)
if err != nil {
t.Errorf("Failed to read file contents: %v", err)
}

if string(fileContent) != expectedContent {
t.Errorf("Expected file fileContent to be %q, got %q", expectedContent, string(fileContent))
}
}

0 comments on commit 90a92f9

Please sign in to comment.