diff --git a/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go index 26f2c48948..712d5948b0 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go @@ -63,28 +63,6 @@ func TestValidateTaskExecutionRequest_MissingFields(t *testing.T) { }, maxOutputSizeInBytes) assert.EqualError(t, err, "missing occurred_at") - err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{ - Event: &event.TaskExecutionEvent{ - OccurredAt: taskEventOccurredAtProto, - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Project: "project", - Domain: "domain", - Name: "name", - }, - ParentNodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "nodey", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - }, - }, - RetryAttempt: 0, - }, - }, maxOutputSizeInBytes) - assert.EqualError(t, err, "missing version") - err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{ Event: &event.TaskExecutionEvent{ OccurredAt: taskEventOccurredAtProto, diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go index be450e4171..d6bc780bbb 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go @@ -110,14 +110,6 @@ func TestValidateTaskEmptyName(t *testing.T) { assert.EqualError(t, err, "missing name") } -func TestValidateTaskEmptyVersion(t *testing.T) { - request := testutils.GetValidTaskRequest() - request.Id.Version = "" - err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) - assert.EqualError(t, err, "missing version") -} - func TestValidateTaskEmptyType(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Type = "" diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index de2927495c..2ff5859b44 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -95,6 +95,23 @@ func ValidateIdentifierFieldsSet(id *core.Identifier) error { return nil } +// ValidateTaskIdentifierFieldsSet Validates that all required fields, except version, for a task identifier are present. +func ValidateTaskIdentifierFieldsSet(id *core.Identifier) error { + if id == nil { + return shared.GetMissingArgumentError(shared.ID) + } + if err := ValidateEmptyStringField(id.Project, shared.Project); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Domain, shared.Domain); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Name, shared.Name); err != nil { + return err + } + return nil +} + // ValidateIdentifier Validates that all required fields for an identifier are present. func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { if id == nil { @@ -105,6 +122,9 @@ func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { "unexpected resource type %s for identifier [%+v], expected %s instead", strings.ToLower(id.ResourceType.String()), id, strings.ToLower(entityToResourceType[expectedType].String())) } + if id.ResourceType == core.ResourceType_TASK { + return ValidateTaskIdentifierFieldsSet(id) + } return ValidateIdentifierFieldsSet(id) } diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 1b42756b7a..82cbd31f40 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -49,14 +49,22 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() - tx := r.db.WithContext(ctx).Where(&models.Task{ - TaskKey: models.TaskKey{ - Project: input.Project, - Domain: input.Domain, - Name: input.Name, - Version: input.Version, - }, - }).Take(&task) + var tx *gorm.DB + if input.Version == "" { + tx = r.db.WithContext(ctx).Where(`"tasks"."project" = ? AND "tasks"."domain" = ? AND "tasks"."name" = ?`, input.Project, input.Domain, input.Name).Limit(1) + tx = tx.Order(`"tasks"."version" DESC`) + tx.Find(&task) + } else { + tx = r.db.WithContext(ctx).Where(&models.Task{ + TaskKey: models.TaskKey{ + Project: input.Project, + Domain: input.Domain, + Name: input.Name, + Version: input.Version, + }, + }).Take(&task) + } + timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 3309ad3609..643dc66e4d 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -79,6 +79,28 @@ func TestGetTask(t *testing.T) { assert.Equal(t, version, output.Version) assert.Equal(t, []byte{1, 2}, output.Closure) assert.Equal(t, pythonTestTaskType, output.Type) + + //When version is empty, return the latest task + GlobalMock = mocket.Catcher.Reset() + GlobalMock.Logging = true + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 ORDER BY "tasks"."version" DESC LIMIT 1`). + WithReply(tasks) + output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ + Project: project, + Domain: domain, + Name: name, + Version: "", + }) + + assert.NoError(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, name, output.Name) + assert.Equal(t, version, output.Version) + assert.Equal(t, []byte{1, 2}, output.Closure) + assert.Equal(t, pythonTestTaskType, output.Type) } func TestListTasks(t *testing.T) { diff --git a/flyteadmin/tests/task_execution_test.go b/flyteadmin/tests/task_execution_test.go index e380104684..c333bd7fd5 100644 --- a/flyteadmin/tests/task_execution_test.go +++ b/flyteadmin/tests/task_execution_test.go @@ -45,6 +45,37 @@ var taskExecutionIdentifier = &core.TaskExecutionIdentifier{ RetryAttempt: 1, } +func TestGetTaskExecutions(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + _, err := client.CreateTask(ctx, &admin.TaskCreateRequest{ + Id: taskIdentifier, + Spec: testutils.GetValidTaskRequest().Spec, + }) + require.NoError(t, err) + + resp, err := client.GetTask(ctx, &admin.ObjectGetRequest{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: project, + Domain: "development", + Name: "task name", + }, + }) + + assert.Nil(t, err) + assert.Equal(t, resp.Id.Project, project) + assert.Equal(t, resp.Id.Domain, "development") + assert.Equal(t, resp.Id.Name, "task name") + assert.Equal(t, resp.Id.Version, "task version") + +} + func createTaskAndNodeExecution( ctx context.Context, t *testing.T, client service.AdminServiceClient, conn *grpc.ClientConn, occurredAtProto *timestamp.Timestamp) {