Skip to content

Commit

Permalink
Merge pull request #143 from PrefectHQ/new-routes
Browse files Browse the repository at this point in the history
2 New GraphQL Routes
  • Loading branch information
cicdw authored Nov 30, 2020
2 parents a3626c8 + 8b13561 commit 1e7f72b
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 291 deletions.
2 changes: 2 additions & 0 deletions changes/pr143.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
enhancement:
- "Add two new GraphQL routes for Core functionality - [#143](https://github.com/PrefectHQ/server/pull/143)"
111 changes: 59 additions & 52 deletions src/prefect_server/api/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,67 +268,74 @@ async def get_or_create_task_run(
raise ValueError("Invalid ID")


@register_api("runs.get_or_create_mapped_task_run_children")
async def get_or_create_mapped_task_run_children(
flow_run_id: str, task_id: str, max_map_index: int
) -> List[str]:
@register_api("runs.get_or_create_task_run_info")
async def get_or_create_task_run_info(
flow_run_id: str, task_id: str, map_index: int = None
) -> dict:
"""
Creates and/or retrieves mapped child task runs for a given flow run and task.
Given a flow_run_id, task_id, and map_index, return details about the corresponding task run.
If the task run doesn't exist, it will be created.
Args:
- flow_run_id (str): the flow run associated with the parent task run
- task_id (str): the task ID to create and/or retrieve
- max_map_index (int,): the number of mapped children e.g., a value of 2 yields 3 mapped children
Returns:
- dict: a dict of details about the task run, including its id, version, and state.
"""
# grab task info
task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
# generate task runs to upsert
task_runs = [
models.TaskRun(
tenant_id=task.tenant_id,
flow_run_id=flow_run_id,
task_id=task_id,
map_index=i,
cache_key=task.cache_key,
)
for i in range(max_map_index + 1)
]
# upsert the mapped children
task_runs = (
await models.TaskRun().insert_many(
objects=task_runs,
on_conflict=dict(
constraint="task_run_unique_identifier_key",
update_columns=["cache_key"],
),
selection_set={"returning": {"id", "map_index"}},
)
)["returning"]
task_runs.sort(key=lambda task_run: task_run.map_index)
# get task runs without states
stateless_runs = await models.TaskRun.where(

if map_index is None:
map_index = -1

task_run = await models.TaskRun.where(
{
"flow_run_id": {"_eq": flow_run_id},
"task_id": {"_eq": task_id},
# this syntax indicates "where there are no states"
"_not": {"states": {}},
"map_index": {"_eq": map_index},
}
).get({"id", "map_index", "version"})
# create and insert states for stateless task runs
task_run_states = [
models.TaskRunState(
tenant_id=task.tenant_id,
task_run_id=task_run.id,
**models.TaskRunState.fields_from_state(
Pending(message="Task run created")
),
).first({"id", "version", "state", "serialized_state"})

if task_run:
return dict(
id=task_run.id,
version=task_run.version,
state=task_run.state,
serialized_state=task_run.serialized_state,
)
for task_run in stateless_runs
]
await models.TaskRunState().insert_many(task_run_states)

# return the task run ids
return [task_run.id for task_run in task_runs]
# if it isn't found, add it to the DB
task = await models.Task.where(id=task_id).first({"cache_key", "tenant_id"})
if not task:
raise ValueError("Invalid task ID")

db_task_run = models.TaskRun(
tenant_id=task.tenant_id,
flow_run_id=flow_run_id,
task_id=task_id,
map_index=map_index,
cache_key=task.cache_key,
version=0,
)

db_task_run_state = models.TaskRunState(
tenant_id=task.tenant_id,
state="Pending",
timestamp=pendulum.now(),
message="Task run created",
serialized_state=Pending(message="Task run created").serialize(),
)

db_task_run.states = [db_task_run_state]
run = await db_task_run.insert(
on_conflict=dict(
constraint="task_run_unique_identifier_key",
update_columns=["cache_key"],
),
selection_set={"returning": {"id"}},
)

return dict(
id=run.returning.id,
version=db_task_run.version,
state="Pending",
serialized_state=db_task_run.serialized_state,
)


@register_api("runs.update_flow_run_heartbeat")
Expand Down
52 changes: 39 additions & 13 deletions src/prefect_server/graphql/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,51 @@

import prefect
from prefect import api
from prefect_server.database import postgres
from prefect_server.database import models, postgres
from prefect_server.utilities import context
from prefect_server.utilities.graphql import mutation, query

state_schema = prefect.serialization.state.StateSchema()


@query.field("get_task_run_info")
async def resolve_get_task_run_info(
obj: Any, info: GraphQLResolveInfo, task_run_id: str
) -> dict:
"""
Retrieve details about a task run.
"""
task_run = await models.TaskRun.where(id=task_run_id).first(
{"version", "serialized_state", "state"}
)
if not task_run:
raise ValueError("Invalid task run ID")

return {
"version": task_run.version,
"serialized_state": task_run.serialized_state,
"state": task_run.state,
"id": task_run_id,
}


@mutation.field("get_or_create_task_run_info")
async def resolve_get_or_create_task_run_info(
obj: Any, info: GraphQLResolveInfo, input: dict
) -> dict:
info = await api.runs.get_or_create_task_run_info(
flow_run_id=input["flow_run_id"],
task_id=input["task_id"],
map_index=input.get("map_index"),
)
return {
"id": info["id"],
"version": info["version"],
"state": info["state"],
"serialized_state": info["serialized_state"],
}


@query.field("mapped_children")
async def resolve_mapped_children(
obj: Any, info: GraphQLResolveInfo, task_run_id: str
Expand Down Expand Up @@ -131,18 +169,6 @@ async def resolve_get_or_create_task_run(
}


@mutation.field("get_or_create_mapped_task_run_children")
async def resolve_get_or_create_mapped_task_run_children(
obj: Any, info: GraphQLResolveInfo, input: dict
) -> List[dict]:
task_runs = await api.runs.get_or_create_mapped_task_run_children(
flow_run_id=input["flow_run_id"],
task_id=input["task_id"],
max_map_index=input["max_map_index"],
)
return {"ids": task_runs}


@mutation.field("delete_flow_run")
async def resolve_delete_flow_run(
obj: Any, info: GraphQLResolveInfo, input: dict
Expand Down
39 changes: 32 additions & 7 deletions src/prefect_server/graphql/schema/runs.graphql
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
extend type Query {
mapped_children(task_run_id: UUID!): mapped_children_payload
mapped_children(task_run_id: UUID!): mapped_children_payload

"""
Given a task run ID, retrieve current task run info
"""
get_task_run_info(
task_run_id: UUID!
): task_run_info_payload

}

Expand All @@ -26,10 +33,12 @@ extend type Mutation {
input: get_or_create_task_run_input!
): task_run_id_payload

"Gets or creates all mapped task run children for a parent task run."
get_or_create_mapped_task_run_children(
input: get_or_create_mapped_task_run_children_input!
): get_or_create_mapped_task_run_children_payload
"""
Given a flow run, task, and map index, retrieve the corresponding task run id
"""
get_or_create_task_run_info(
input: get_or_create_task_run_info_input!
): get_or_create_task_run_info_payload

"Update a flow run's heartbeat. This indicates the flow run is alive and is called automatically by Prefect Core."
update_flow_run_heartbeat(
Expand Down Expand Up @@ -95,6 +104,12 @@ input get_or_create_task_run_input {
map_index: Int
}

input get_or_create_task_run_info_input {
flow_run_id: UUID!
task_id: UUID!
map_index: Int
}

input get_or_create_mapped_task_run_children_input {
flow_run_id: UUID!
task_id: UUID!
Expand Down Expand Up @@ -127,8 +142,18 @@ type task_run_id_payload {
id: UUID
}

type get_or_create_mapped_task_run_children_payload {
ids: [UUID!]
type task_run_info_payload {
id: UUID
version: Int
serialized_state: JSON
state: String
}

type get_or_create_task_run_info_payload {
id: UUID
version: Int
state: String
serialized_state: JSON
}

type runs_in_queue_payload {
Expand Down
Loading

0 comments on commit 1e7f72b

Please sign in to comment.