Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reference task to fetch latest version by default #5895

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,6 @@ func TestValidateTaskExecutionRequest_MissingFields(t *testing.T) {
}, maxOutputSizeInBytes)
assert.EqualError(t, err, "missing occurred_at")

err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{
Event: &event.TaskExecutionEvent{
OccurredAt: taskEventOccurredAtProto,
TaskId: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "project",
Domain: "domain",
Name: "name",
},
ParentNodeExecutionId: &core.NodeExecutionIdentifier{
NodeId: "nodey",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
},
RetryAttempt: 0,
},
}, maxOutputSizeInBytes)
assert.EqualError(t, err, "missing version")

err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{
Event: &event.TaskExecutionEvent{
OccurredAt: taskEventOccurredAtProto,
Expand Down
8 changes: 0 additions & 8 deletions flyteadmin/pkg/manager/impl/validation/task_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,6 @@ func TestValidateTaskEmptyName(t *testing.T) {
assert.EqualError(t, err, "missing name")
}

func TestValidateTaskEmptyVersion(t *testing.T) {
request := testutils.GetValidTaskRequest()
request.Id.Version = ""
err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(),
getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider)
assert.EqualError(t, err, "missing version")
}

func TestValidateTaskEmptyType(t *testing.T) {
request := testutils.GetValidTaskRequest()
request.Spec.Template.Type = ""
Expand Down
20 changes: 20 additions & 0 deletions flyteadmin/pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@
return nil
}

// ValidateTaskIdentifierFieldsSet Validates that all required fields, except version, for a task identifier are present.
func ValidateTaskIdentifierFieldsSet(id *core.Identifier) error {
if id == nil {
return shared.GetMissingArgumentError(shared.ID)
}

Check warning on line 102 in flyteadmin/pkg/manager/impl/validation/validation.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/manager/impl/validation/validation.go#L101-L102

Added lines #L101 - L102 were not covered by tests
if err := ValidateEmptyStringField(id.Project, shared.Project); err != nil {
return err
}
if err := ValidateEmptyStringField(id.Domain, shared.Domain); err != nil {
return err
}
if err := ValidateEmptyStringField(id.Name, shared.Name); err != nil {
return err
}
return nil
}

// ValidateIdentifier Validates that all required fields for an identifier are present.
func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error {
if id == nil {
Expand All @@ -105,6 +122,9 @@
"unexpected resource type %s for identifier [%+v], expected %s instead",
strings.ToLower(id.ResourceType.String()), id, strings.ToLower(entityToResourceType[expectedType].String()))
}
if id.ResourceType == core.ResourceType_TASK {
return ValidateTaskIdentifierFieldsSet(id)
}
return ValidateIdentifierFieldsSet(id)
}

Expand Down
24 changes: 16 additions & 8 deletions flyteadmin/pkg/repositories/gormimpl/task_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,22 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt
func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) {
var task models.Task
timer := r.metrics.GetDuration.Start()
tx := r.db.WithContext(ctx).Where(&models.Task{
TaskKey: models.TaskKey{
Project: input.Project,
Domain: input.Domain,
Name: input.Name,
Version: input.Version,
},
}).Take(&task)
var tx *gorm.DB
if input.Version == "" {
tx = r.db.WithContext(ctx).Where(`"tasks"."project" = ? AND "tasks"."domain" = ? AND "tasks"."name" = ?`, input.Project, input.Domain, input.Name).Limit(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, are we able to use TaskKey here?

models.TaskKey{
  Project: input.Project,
  Domain:  input.Domain,
  Name:    input.Name,
}

tx = tx.Order(`"tasks"."version" DESC`)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we order by version instead of time information such as created_at?

tx.Find(&task)
} else {
tx = r.db.WithContext(ctx).Where(&models.Task{
TaskKey: models.TaskKey{
Project: input.Project,
Domain: input.Domain,
Name: input.Name,
Version: input.Version,
},
}).Take(&task)
}

timer.Stop()
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{
Expand Down
22 changes: 22 additions & 0 deletions flyteadmin/pkg/repositories/gormimpl/task_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,28 @@ func TestGetTask(t *testing.T) {
assert.Equal(t, version, output.Version)
assert.Equal(t, []byte{1, 2}, output.Closure)
assert.Equal(t, pythonTestTaskType, output.Type)

//When version is empty, return the latest task
GlobalMock = mocket.Catcher.Reset()
GlobalMock.Logging = true

GlobalMock.NewMock().WithQuery(
`SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 ORDER BY "tasks"."version" DESC LIMIT 1`).
WithReply(tasks)
output, err = taskRepo.Get(context.Background(), interfaces.Identifier{
Project: project,
Domain: domain,
Name: name,
Version: "",
})

assert.NoError(t, err)
assert.Equal(t, project, output.Project)
assert.Equal(t, domain, output.Domain)
assert.Equal(t, name, output.Name)
assert.Equal(t, version, output.Version)
assert.Equal(t, []byte{1, 2}, output.Closure)
assert.Equal(t, pythonTestTaskType, output.Type)
}

func TestListTasks(t *testing.T) {
Expand Down
31 changes: 31 additions & 0 deletions flyteadmin/tests/task_execution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,37 @@ var taskExecutionIdentifier = &core.TaskExecutionIdentifier{
RetryAttempt: 1,
}

func TestGetTaskExecutions(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we register more than one tasks and see if we get the latest one?

truncateAllTablesForTestingOnly()
populateWorkflowExecutionsForTestingOnly()

ctx := context.Background()
client, conn := GetTestAdminServiceClient()
defer conn.Close()

_, err := client.CreateTask(ctx, &admin.TaskCreateRequest{
Id: taskIdentifier,
Spec: testutils.GetValidTaskRequest().Spec,
})
require.NoError(t, err)

resp, err := client.GetTask(ctx, &admin.ObjectGetRequest{
Id: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: project,
Domain: "development",
Name: "task name",
},
})

assert.Nil(t, err)
assert.Equal(t, resp.Id.Project, project)
assert.Equal(t, resp.Id.Domain, "development")
assert.Equal(t, resp.Id.Name, "task name")
assert.Equal(t, resp.Id.Version, "task version")

}

func createTaskAndNodeExecution(
ctx context.Context, t *testing.T, client service.AdminServiceClient, conn *grpc.ClientConn,
occurredAtProto *timestamp.Timestamp) {
Expand Down
Loading